Merge remote-tracking branch 'upstream/release-v1.30.0'

This commit is contained in:
Tulir Asokan 2021-03-16 16:58:38 +02:00
commit 8b753230af
102 changed files with 2693 additions and 933 deletions

8
.git-blame-ignore-revs Normal file
View file

@ -0,0 +1,8 @@
# Black reformatting (#5482).
32e7c9e7f20b57dd081023ac42d6931a8da9b3a3
# Target Python 3.5 with black (#8664).
aff1eb7c671b0a3813407321d2702ec46c71fa56
# Update black to 20.8b1 (#9381).
0a00b7ff14890987f09112a2ae696c61001e6cf1

View file

@ -1,3 +1,80 @@
Synapse 1.30.0rc1 (2021-03-16)
==============================
Note that this release deprecates the ability for appservices to
call `POST /_matrix/client/r0/register` without the body parameter `type`. Appservice
developers should use a `type` value of `m.login.application_service` as
per [the spec](https://matrix.org/docs/spec/application_service/r0.1.2#server-admin-style-permissions).
In future releases, calling this endpoint with an access token - but without a `m.login.application_service`
type - will fail.
Features
--------
- Add prometheus metrics for number of users successfully registering and logging in. ([\#9510](https://github.com/matrix-org/synapse/issues/9510), [\#9511](https://github.com/matrix-org/synapse/issues/9511), [\#9573](https://github.com/matrix-org/synapse/issues/9573))
- Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent and received to a set of remote servers. ([\#9540](https://github.com/matrix-org/synapse/issues/9540))
- Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets. ([\#9549](https://github.com/matrix-org/synapse/issues/9549))
- Optimise handling of incomplete room history for incoming federation. ([\#9601](https://github.com/matrix-org/synapse/issues/9601))
- Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). ([\#9617](https://github.com/matrix-org/synapse/issues/9617))
- Tell spam checker modules about the SSO IdP a user registered through if one was used. ([\#9626](https://github.com/matrix-org/synapse/issues/9626))
Bugfixes
--------
- Fix long-standing bug when generating thumbnails for some images with transparency: `TypeError: cannot unpack non-iterable int object`. ([\#9473](https://github.com/matrix-org/synapse/issues/9473))
- Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. ([\#9542](https://github.com/matrix-org/synapse/issues/9542), [\#9583](https://github.com/matrix-org/synapse/issues/9583))
- Fix bug where federation requests were not correctly retried on 5xx responses. ([\#9567](https://github.com/matrix-org/synapse/issues/9567))
- Fix re-activating an account via the admin API when local passwords are disabled. ([\#9587](https://github.com/matrix-org/synapse/issues/9587))
- Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages. ([\#9597](https://github.com/matrix-org/synapse/issues/9597))
- Fix a bug introduced in v1.28.0 where the OpenID Connect callback endpoint could error with a `MacaroonInitException`. ([\#9620](https://github.com/matrix-org/synapse/issues/9620))
- Fix Internal Server Error on `GET /_synapse/client/saml2/authn_response` request. ([\#9623](https://github.com/matrix-org/synapse/issues/9623))
Updates to the Docker image
---------------------------
- Make use of an improved malloc implementation (`jemalloc`) in the docker image. ([\#8553](https://github.com/matrix-org/synapse/issues/8553))
Improved Documentation
----------------------
- Add relayd entry to reverse proxy example configurations. ([\#9508](https://github.com/matrix-org/synapse/issues/9508))
- Improve the SAML2 upgrade notes for 1.27.0. ([\#9550](https://github.com/matrix-org/synapse/issues/9550))
- Link to the "List user's media" admin API from the media admin API docs. ([\#9571](https://github.com/matrix-org/synapse/issues/9571))
- Clarify the spam checker modules documentation example to mention that `parse_config` is a required method. ([\#9580](https://github.com/matrix-org/synapse/issues/9580))
- Clarify the sample configuration for `stats` settings. ([\#9604](https://github.com/matrix-org/synapse/issues/9604))
Deprecations and Removals
-------------------------
- The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`. ([\#9540](https://github.com/matrix-org/synapse/issues/9540))
- Registering an Application Service user without using the `m.login.application_service` login type will be unsupported in an upcoming Synapse release. ([\#9559](https://github.com/matrix-org/synapse/issues/9559))
Internal Changes
----------------
- Add tests to ResponseCache. ([\#9458](https://github.com/matrix-org/synapse/issues/9458))
- Add type hints to purge room and server notice admin API. ([\#9520](https://github.com/matrix-org/synapse/issues/9520))
- Add extra logging to ObservableDeferred when callbacks throw exceptions. ([\#9523](https://github.com/matrix-org/synapse/issues/9523))
- Fix incorrect type hints. ([\#9528](https://github.com/matrix-org/synapse/issues/9528), [\#9543](https://github.com/matrix-org/synapse/issues/9543), [\#9591](https://github.com/matrix-org/synapse/issues/9591), [\#9608](https://github.com/matrix-org/synapse/issues/9608), [\#9618](https://github.com/matrix-org/synapse/issues/9618))
- Add an additional test for purging a room. ([\#9541](https://github.com/matrix-org/synapse/issues/9541))
- Add a `.git-blame-ignore-revs` file with the hashes of auto-formatting. ([\#9560](https://github.com/matrix-org/synapse/issues/9560))
- Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. ([\#9561](https://github.com/matrix-org/synapse/issues/9561))
- Fix spurious errors reported by the `config-lint.sh` script. ([\#9562](https://github.com/matrix-org/synapse/issues/9562))
- Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper. ([\#9563](https://github.com/matrix-org/synapse/issues/9563))
- Do not have mypy ignore type hints from unpaddedbase64. ([\#9568](https://github.com/matrix-org/synapse/issues/9568))
- Improve efficiency of calculating the auth chain in large rooms. ([\#9576](https://github.com/matrix-org/synapse/issues/9576))
- Convert `synapse.types.Requester` to an `attrs` class. ([\#9586](https://github.com/matrix-org/synapse/issues/9586))
- Add logging for redis connection setup. ([\#9590](https://github.com/matrix-org/synapse/issues/9590))
- Improve logging when processing incoming transactions. ([\#9596](https://github.com/matrix-org/synapse/issues/9596))
- Remove unused `stats.retention` setting, and emit a warning if stats are disabled. ([\#9604](https://github.com/matrix-org/synapse/issues/9604))
- Prevent attempting to bundle aggregations for state events in /context APIs. ([\#9619](https://github.com/matrix-org/synapse/issues/9619))
Synapse 1.29.0 (2021-03-08)
===========================

View file

@ -20,9 +20,10 @@ recursive-include scripts *
recursive-include scripts-dev *
recursive-include synapse *.pyi
recursive-include tests *.py
include tests/http/ca.crt
include tests/http/ca.key
include tests/http/server.key
recursive-include tests *.pem
recursive-include tests *.p8
recursive-include tests *.crt
recursive-include tests *.key
recursive-include synapse/res *
recursive-include synapse/static *.css

View file

@ -183,8 +183,9 @@ Using a reverse proxy with Synapse
It is recommended to put a reverse proxy such as
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or
`HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_,
`HAProxy <https://www.haproxy.org/>`_ or
`relayd <https://man.openbsd.org/relayd.8>`_ in front of Synapse. One advantage of
doing so is that it means that you can expose the default https port (443) to
Matrix clients without needing to run Synapse with root privileges.

View file

@ -124,6 +124,13 @@ This version changes the URI used for callbacks from OAuth2 and SAML2 identity p
need to add ``[synapse public baseurl]/_synapse/client/saml2/authn_response`` as a permitted
"ACS location" (also known as "allowed callback URLs") at the identity provider.
The "Issuer" in the "AuthnRequest" to the SAML2 identity provider is also updated to
``[synapse public baseurl]/_synapse/client/saml2/metadata.xml``. If your SAML2 identity
provider uses this property to validate or otherwise identify Synapse, its configuration
will need to be updated to use the new URL. Alternatively you could create a new, separate
"EntityDescriptor" in your SAML2 identity provider with the new URLs and leave the URLs in
the existing "EntityDescriptor" as they were.
Changes to HTML templates
-------------------------

View file

@ -69,6 +69,7 @@ RUN apt-get update && apt-get install -y \
libpq5 \
libwebp6 \
xmlsec1 \
libjemalloc2 \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /install /usr/local

View file

@ -204,3 +204,8 @@ healthcheck:
timeout: 10s
retries: 3
```
## Using jemalloc
Jemalloc is embedded in the image and will be used instead of the default allocator.
You can read about jemalloc by reading the Synapse [README](../README.md)

View file

@ -3,6 +3,7 @@
import codecs
import glob
import os
import platform
import subprocess
import sys
@ -213,6 +214,13 @@ def main(args, environ):
if "-m" not in args:
args = ["-m", synapse_worker] + args
jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),)
if os.path.isfile(jemallocpath):
environ["LD_PRELOAD"] = jemallocpath
else:
log("Could not find %s, will not use" % (jemallocpath,))
# if there are no config files passed to synapse, try adding the default file
if not any(p.startswith("--config-path") or p.startswith("-c") for p in args):
config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data")
@ -248,9 +256,9 @@ running with 'migrate_config'. See the README for more details.
args = ["python"] + args
if ownership is not None:
args = ["gosu", ownership] + args
os.execv("/usr/sbin/gosu", args)
os.execve("/usr/sbin/gosu", args, environ)
else:
os.execv("/usr/local/bin/python", args)
os.execve("/usr/local/bin/python", args, environ)
if __name__ == "__main__":

View file

@ -1,5 +1,7 @@
# Contents
- [List all media in a room](#list-all-media-in-a-room)
- [Querying media](#querying-media)
* [List all media in a room](#list-all-media-in-a-room)
* [List all media uploaded by a user](#list-all-media-uploaded-by-a-user)
- [Quarantine media](#quarantine-media)
* [Quarantining media by ID](#quarantining-media-by-id)
* [Quarantining media in a room](#quarantining-media-in-a-room)
@ -10,7 +12,11 @@
* [Delete local media by date or size](#delete-local-media-by-date-or-size)
- [Purge Remote Media API](#purge-remote-media-api)
# List all media in a room
# Querying media
These APIs allow extracting media information from the homeserver.
## List all media in a room
This API gets a list of known media in a room.
However, it only shows media from unencrypted events or rooms.
@ -36,6 +42,12 @@ The API returns a JSON body like the following:
}
```
## List all media uploaded by a user
Listing all media that has been uploaded by a local user can be achieved through
the use of the [List media of a user](user_admin_api.rst#list-media-of-a-user)
Admin API.
# Quarantine media
Quarantining media means that it is marked as inaccessible by users. It applies

View file

@ -226,7 +226,7 @@ Synapse config:
oidc_providers:
- idp_id: github
idp_name: Github
idp_brand: "org.matrix.github" # optional: styling hint for clients
idp_brand: "github" # optional: styling hint for clients
discover: false
issuer: "https://github.com/"
client_id: "your-client-id" # TO BE FILLED
@ -252,7 +252,7 @@ oidc_providers:
oidc_providers:
- idp_id: google
idp_name: Google
idp_brand: "org.matrix.google" # optional: styling hint for clients
idp_brand: "google" # optional: styling hint for clients
issuer: "https://accounts.google.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
@ -299,7 +299,7 @@ Synapse config:
oidc_providers:
- idp_id: gitlab
idp_name: Gitlab
idp_brand: "org.matrix.gitlab" # optional: styling hint for clients
idp_brand: "gitlab" # optional: styling hint for clients
issuer: "https://gitlab.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
@ -334,7 +334,7 @@ Synapse config:
```yaml
- idp_id: facebook
idp_name: Facebook
idp_brand: "org.matrix.facebook" # optional: styling hint for clients
idp_brand: "facebook" # optional: styling hint for clients
discover: false
issuer: "https://facebook.com"
client_id: "your-client-id" # TO BE FILLED
@ -386,7 +386,7 @@ oidc_providers:
config:
subject_claim: "id"
localpart_template: "{{ user.login }}"
display_name_template: "{{ user.full_name }}"
display_name_template: "{{ user.full_name }}"
```
### XWiki
@ -401,8 +401,7 @@ oidc_providers:
idp_name: "XWiki"
issuer: "https://myxwikihost/xwiki/oidc/"
client_id: "your-client-id" # TO BE FILLED
# Needed until https://github.com/matrix-org/synapse/issues/9212 is fixed
client_secret: "dontcare"
client_auth_method: none
scopes: ["openid", "profile"]
user_profile_method: "userinfo_endpoint"
user_mapping_provider:
@ -410,3 +409,40 @@ oidc_providers:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}"
```
## Apple
Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account.
You will need to create a new "Services ID" for SiWA, and create and download a
private key with "SiWA" enabled.
As well as the private key file, you will need:
* Client ID: the "identifier" you gave the "Services ID"
* Team ID: a 10-character ID associated with your developer account.
* Key ID: the 10-character identifier for the key.
https://help.apple.com/developer-account/?lang=en#/dev77c875b7e has more
documentation on setting up SiWA.
The synapse config will look like this:
```yaml
- idp_id: apple
idp_name: Apple
issuer: "https://appleid.apple.com"
client_id: "your-client-id" # Set to the "identifier" for your "ServicesID"
client_auth_method: "client_secret_post"
client_secret_jwt_key:
key_file: "/path/to/AuthKey_KEYIDCODE.p8" # point to your key file
jwt_header:
alg: ES256
kid: "KEYIDCODE" # Set to the 10-char Key ID
jwt_payload:
iss: TEAMIDCODE # Set to the 10-char Team ID
scopes: ["name", "email", "openid"]
authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post
user_mapping_provider:
config:
email_template: "{{ user.email }}"
```

View file

@ -3,8 +3,9 @@
It is recommended to put a reverse proxy such as
[nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html),
[Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html),
[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or
[HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage
[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy),
[HAProxy](https://www.haproxy.org/) or
[relayd](https://man.openbsd.org/relayd.8) in front of Synapse. One advantage
of doing so is that it means that you can expose the default https port
(443) to Matrix clients without needing to run Synapse with root
privileges.
@ -162,6 +163,52 @@ backend matrix
server matrix 127.0.0.1:8008
```
### Relayd
```
table <webserver> { 127.0.0.1 }
table <matrixserver> { 127.0.0.1 }
http protocol "https" {
tls { no tlsv1.0, ciphers "HIGH" }
tls keypair "example.com"
match header set "X-Forwarded-For" value "$REMOTE_ADDR"
match header set "X-Forwarded-Proto" value "https"
# set CORS header for .well-known/matrix/server, .well-known/matrix/client
# httpd does not support setting headers, so do it here
match request path "/.well-known/matrix/*" tag "matrix-cors"
match response tagged "matrix-cors" header set "Access-Control-Allow-Origin" value "*"
pass quick path "/_matrix/*" forward to <matrixserver>
pass quick path "/_synapse/client/*" forward to <matrixserver>
# pass on non-matrix traffic to webserver
pass forward to <webserver>
}
relay "https_traffic" {
listen on egress port 443 tls
protocol "https"
forward to <matrixserver> port 8008 check tcp
forward to <webserver> port 8080 check tcp
}
http protocol "matrix" {
tls { no tlsv1.0, ciphers "HIGH" }
tls keypair "example.com"
block
pass quick path "/_matrix/*" forward to <matrixserver>
pass quick path "/_synapse/client/*" forward to <matrixserver>
}
relay "matrix_federation" {
listen on egress port 8448 tls
protocol "matrix"
forward to <matrixserver> port 8008 check tcp
}
```
## Homeserver Configuration
You will also want to set `bind_addresses: ['127.0.0.1']` and

View file

@ -89,8 +89,7 @@ pid_file: DATADIR/homeserver.pid
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation
# API, so this setting is of limited value if federation is enabled on
# the server.
# API, unless allow_profile_lookup_over_federation is set to false.
#
#require_auth_for_profile_requests: true
@ -1780,7 +1779,26 @@ saml2_config:
#
# client_id: Required. oauth2 client id to use.
#
# client_secret: Required. oauth2 client secret to use.
# client_secret: oauth2 client secret to use. May be omitted if
# client_secret_jwt_key is given, or if client_auth_method is 'none'.
#
# client_secret_jwt_key: Alternative to client_secret: details of a key used
# to create a JSON Web Token to be used as an OAuth2 client secret. If
# given, must be a dictionary with the following properties:
#
# key: a pem-encoded signing key. Must be a suitable key for the
# algorithm specified. Required unless 'key_file' is given.
#
# key_file: the path to file containing a pem-encoded signing key file.
# Required unless 'key' is given.
#
# jwt_header: a dictionary giving properties to include in the JWT
# header. Must include the key 'alg', giving the algorithm used to
# sign the JWT, such as "ES256", using the JWA identifiers in
# RFC7518.
#
# jwt_payload: an optional dictionary giving properties to include in
# the JWT payload. Normally this should include an 'iss' key.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
@ -1901,7 +1919,7 @@ oidc_providers:
#
#- idp_id: github
# idp_name: Github
# idp_brand: org.matrix.github
# idp_brand: github
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
@ -2627,19 +2645,20 @@ user_directory:
# Local statistics collection. Used in populating the room directory.
# Settings for local room and user statistics collection. See
# docs/room_and_user_statistics.md.
#
# 'bucket_size' controls how large each statistics timeslice is. It can
# be defined in a human readable short form -- e.g. "1d", "1y".
#
# 'retention' controls how long historical statistics will be kept for.
# It can be defined in a human readable short form -- e.g. "1d", "1y".
#
#
#stats:
# enabled: true
# bucket_size: 1d
# retention: 1y
stats:
# Uncomment the following to disable room and user statistics. Note that doing
# so may cause certain features (such as the room directory) not to work
# correctly.
#
#enabled: false
# The size of each timeslice in the room_stats_historical and
# user_stats_historical tables, as a time period. Defaults to "1d".
#
#bucket_size: 1h
# Server Notices room configuration

View file

@ -14,6 +14,7 @@ The Python class is instantiated with two objects:
* An instance of `synapse.module_api.ModuleApi`.
It then implements methods which return a boolean to alter behavior in Synapse.
All the methods must be defined.
There's a generic method for checking every event (`check_event_for_spam`), as
well as some specific methods:
@ -24,6 +25,7 @@ well as some specific methods:
* `user_may_publish_room`
* `check_username_for_spam`
* `check_registration_for_spam`
* `check_media_file_for_spam`
The details of each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class.
@ -31,6 +33,10 @@ are documented in the `synapse.events.spamcheck.SpamChecker` class.
The `ModuleApi` class provides a way for the custom spam checker class to
call back into the homeserver internals.
Additionally, a `parse_config` method is mandatory and receives the plugin config
dictionary. After parsing, It must return an object which will be
passed to `__init__` later.
### Example
```python
@ -41,6 +47,10 @@ class ExampleSpamChecker:
self.config = config
self.api = api
@staticmethod
def parse_config(config):
return config
async def check_event_for_spam(self, foo):
return False # allow all events
@ -59,7 +69,13 @@ class ExampleSpamChecker:
async def check_username_for_spam(self, user_profile):
return False # allow all usernames
async def check_registration_for_spam(self, email_threepid, username, request_info):
async def check_registration_for_spam(
self,
email_threepid,
username,
request_info,
auth_provider_id,
):
return RegistrationBehaviour.ALLOW # allow all registrations
async def check_media_file_for_spam(self, file_wrapper, file_info):

View file

@ -69,6 +69,7 @@ files =
synapse/util/async_helpers.py,
synapse/util/caches,
synapse/util/metrics.py,
synapse/util/macaroons.py,
synapse/util/stringutils.py,
tests/replication,
tests/test_utils,
@ -116,9 +117,6 @@ ignore_missing_imports = True
[mypy-saml2.*]
ignore_missing_imports = True
[mypy-unpaddedbase64]
ignore_missing_imports = True
[mypy-canonicaljson]
ignore_missing_imports = True

View file

@ -2,9 +2,14 @@
# Find linting errors in Synapse's default config file.
# Exits with 0 if there are no problems, or another code otherwise.
# cd to the root of the repository
cd `dirname $0`/..
# Restore backup of sample config upon script exit
trap "mv docs/sample_config.yaml.bak docs/sample_config.yaml" EXIT
# Fix non-lowercase true/false values
sed -i.bak -E "s/: +True/: true/g; s/: +False/: false/g;" docs/sample_config.yaml
rm docs/sample_config.yaml.bak
# Check if anything changed
git diff --exit-code docs/sample_config.yaml
diff docs/sample_config.yaml docs/sample_config.yaml.bak

View file

@ -3,6 +3,7 @@ test_suite = tests
[check-manifest]
ignore =
.git-blame-ignore-revs
contrib
contrib/*
docs/*

View file

@ -17,7 +17,9 @@
"""
from typing import Any, List, Optional, Type, Union
class RedisProtocol:
from twisted.internet import protocol
class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
async def set(
@ -52,7 +54,7 @@ def lazyConnection(
class ConnectionHandler: ...
class RedisFactory:
class RedisFactory(protocol.ReconnectingClientFactory):
continueTrying: bool
handler: RedisProtocol
pool: List[RedisProtocol]

View file

@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.29.0"
__version__ = "1.30.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@ -163,7 +164,7 @@ class Auth:
async def get_user_by_req(
self,
request: Request,
request: SynapseRequest,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
@ -408,7 +409,7 @@ class Auth:
raise _InvalidMacaroonException()
try:
user_id = self.get_user_id_from_macaroon(macaroon)
user_id = get_value_from_macaroon(macaroon, "user_id")
guest = False
for caveat in macaroon.caveats:
@ -416,7 +417,12 @@ class Auth:
guest = True
self.validate_macaroon(macaroon, rights, user_id=user_id)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
except (
pymacaroons.exceptions.MacaroonException,
KeyError,
TypeError,
ValueError,
):
raise InvalidClientTokenError("Invalid macaroon passed.")
if rights == "access":
@ -424,27 +430,6 @@ class Auth:
return user_id, guest
def get_user_id_from_macaroon(self, macaroon):
"""Retrieve the user_id given by the caveats on the macaroon.
Does *not* validate the macaroon.
Args:
macaroon (pymacaroons.Macaroon): The macaroon to validate
Returns:
(str) user id
Raises:
InvalidClientCredentialsError if there is no user_id caveat in the
macaroon
"""
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix) :]
raise InvalidClientTokenError("No user caveat in macaroon")
def validate_macaroon(self, macaroon, type_string, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@ -465,21 +450,13 @@ class Auth:
v.satisfy_exact("type = " + type_string)
v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true")
v.satisfy_general(self._verify_expiry)
satisfy_expiry(v, self.clock.time_msec)
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.verify(macaroon, self._macaroon_secret_key)
def _verify_expiry(self, caveat):
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self.hs.get_clock().time_msec()
return now < expiry
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)

View file

@ -90,7 +90,7 @@ class ApplicationServiceApi(SimpleHttpClient):
self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
) # type: ResponseCache[Tuple[str, str]]
async def query_user(self, service, user_id):

View file

@ -212,9 +212,8 @@ class Config:
@classmethod
def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
with open(file_path) as file_stream:
return file_stream.read()
"""Deprecated: call read_file directly"""
return read_file(file_path, (config_name,))
def read_template(self, filename: str) -> jinja2.Template:
"""Load a template file from disk.
@ -894,4 +893,35 @@ class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
return self._get_instance(key)
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
def read_file(file_path: Any, config_path: Iterable[str]) -> str:
"""Check the given file exists, and read it into a string
If it does not, emit an error indicating the problem
Args:
file_path: the file to be read
config_path: where in the configuration file_path came from, so that a useful
error can be emitted if it does not exist.
Returns:
content of the file.
Raises:
ConfigError if there is a problem reading the file.
"""
if not isinstance(file_path, str):
raise ConfigError("%r is not a string", config_path)
try:
os.stat(file_path)
with open(file_path) as file_stream:
return file_stream.read()
except OSError as e:
raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
__all__ = [
"Config",
"RootConfig",
"ShardedWorkerHandlingConfig",
"RoutableShardedWorkerHandlingConfig",
"read_file",
]

View file

@ -152,3 +152,5 @@ class ShardedWorkerHandlingConfig:
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
def get_instance(self, key: str) -> str: ...
def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...

View file

@ -21,8 +21,10 @@ import threading
from string import Template
import yaml
from zope.interface import implementer
from twisted.logger import (
ILogObserver,
LogBeginner,
STDLibLogObserver,
eventAsText,
@ -227,7 +229,8 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
threadlocal = threading.local()
def _log(event):
@implementer(ILogObserver)
def _log(event: dict) -> None:
if "log_text" in event:
if event["log_text"].startswith("DNSDatagramProtocol starting on "):
return

View file

@ -15,7 +15,7 @@
# limitations under the License.
from collections import Counter
from typing import Iterable, Optional, Tuple, Type
from typing import Iterable, Mapping, Optional, Tuple, Type
import attr
@ -25,7 +25,7 @@ from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_mxc_uri
from ._base import Config, ConfigError
from ._base import Config, ConfigError, read_file
DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
@ -97,7 +97,26 @@ class OIDCConfig(Config):
#
# client_id: Required. oauth2 client id to use.
#
# client_secret: Required. oauth2 client secret to use.
# client_secret: oauth2 client secret to use. May be omitted if
# client_secret_jwt_key is given, or if client_auth_method is 'none'.
#
# client_secret_jwt_key: Alternative to client_secret: details of a key used
# to create a JSON Web Token to be used as an OAuth2 client secret. If
# given, must be a dictionary with the following properties:
#
# key: a pem-encoded signing key. Must be a suitable key for the
# algorithm specified. Required unless 'key_file' is given.
#
# key_file: the path to file containing a pem-encoded signing key file.
# Required unless 'key' is given.
#
# jwt_header: a dictionary giving properties to include in the JWT
# header. Must include the key 'alg', giving the algorithm used to
# sign the JWT, such as "ES256", using the JWA identifiers in
# RFC7518.
#
# jwt_payload: an optional dictionary giving properties to include in
# the JWT payload. Normally this should include an 'iss' key.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
@ -218,7 +237,7 @@ class OIDCConfig(Config):
#
#- idp_id: github
# idp_name: Github
# idp_brand: org.matrix.github
# idp_brand: github
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
@ -240,7 +259,7 @@ class OIDCConfig(Config):
# jsonschema definition of the configuration settings for an oidc identity provider
OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "object",
"required": ["issuer", "client_id", "client_secret"],
"required": ["issuer", "client_id"],
"properties": {
"idp_id": {
"type": "string",
@ -253,7 +272,12 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"idp_icon": {"type": "string"},
"idp_brand": {
"type": "string",
# MSC2758-style namespaced identifier
"minLength": 1,
"maxLength": 255,
"pattern": "^[a-z][a-z0-9_.-]*$",
},
"idp_unstable_brand": {
"type": "string",
"minLength": 1,
"maxLength": 255,
"pattern": "^[a-z][a-z0-9_.-]*$",
@ -262,6 +286,30 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"issuer": {"type": "string"},
"client_id": {"type": "string"},
"client_secret": {"type": "string"},
"client_secret_jwt_key": {
"type": "object",
"required": ["jwt_header"],
"oneOf": [
{"required": ["key"]},
{"required": ["key_file"]},
],
"properties": {
"key": {"type": "string"},
"key_file": {"type": "string"},
"jwt_header": {
"type": "object",
"required": ["alg"],
"properties": {
"alg": {"type": "string"},
},
"additionalProperties": {"type": "string"},
},
"jwt_payload": {
"type": "object",
"additionalProperties": {"type": "string"},
},
},
},
"client_auth_method": {
"type": "string",
# the following list is the same as the keys of
@ -404,15 +452,31 @@ def _parse_oidc_config_dict(
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
) from e
client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey]
if client_secret_jwt_key_config is not None:
keyfile = client_secret_jwt_key_config.get("key_file")
if keyfile:
key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
else:
key = client_secret_jwt_key_config["key"]
client_secret_jwt_key = OidcProviderClientSecretJwtKey(
key=key,
jwt_header=client_secret_jwt_key_config["jwt_header"],
jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
)
return OidcProviderConfig(
idp_id=idp_id,
idp_name=oidc_config.get("idp_name", "OIDC"),
idp_icon=idp_icon,
idp_brand=oidc_config.get("idp_brand"),
unstable_idp_brand=oidc_config.get("unstable_idp_brand"),
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
client_secret=oidc_config["client_secret"],
client_secret=oidc_config.get("client_secret"),
client_secret_jwt_key=client_secret_jwt_key,
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
scopes=oidc_config.get("scopes", ["openid"]),
authorization_endpoint=oidc_config.get("authorization_endpoint"),
@ -427,6 +491,18 @@ def _parse_oidc_config_dict(
)
@attr.s(slots=True, frozen=True)
class OidcProviderClientSecretJwtKey:
# a pem-encoded signing key
key = attr.ib(type=str)
# properties to include in the JWT header
jwt_header = attr.ib(type=Mapping[str, str])
# properties to include in the JWT payload.
jwt_payload = attr.ib(type=Mapping[str, str])
@attr.s(slots=True, frozen=True)
class OidcProviderConfig:
# a unique identifier for this identity provider. Used in the 'user_external_ids'
@ -442,6 +518,9 @@ class OidcProviderConfig:
# Optional brand identifier for this IdP.
idp_brand = attr.ib(type=Optional[str])
# Optional brand identifier for the unstable API (see MSC2858).
unstable_idp_brand = attr.ib(type=Optional[str])
# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)
@ -452,8 +531,13 @@ class OidcProviderConfig:
# oauth2 client id to use
client_id = attr.ib(type=str)
# oauth2 client secret to use
client_secret = attr.ib(type=str)
# oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
# a secret.
client_secret = attr.ib(type=Optional[str])
# key to use to construct a JWT to use as a client secret. May be `None` if
# `client_secret` is set.
client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and

View file

@ -841,8 +841,7 @@ class ServerConfig(Config):
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation
# API, so this setting is of limited value if federation is enabled on
# the server.
# API, unless allow_profile_lookup_over_federation is set to false.
#
#require_auth_for_profile_requests: true

View file

@ -13,10 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import logging
from ._base import Config
ROOM_STATS_DISABLED_WARN = """\
WARNING: room/user statistics have been disabled via the stats.enabled
configuration setting. This means that certain features (such as the room
directory) will not operate correctly. Future versions of Synapse may ignore
this setting.
To fix this warning, remove the stats.enabled setting from your configuration
file.
--------------------------------------------------------------------------------"""
logger = logging.getLogger(__name__)
class StatsConfig(Config):
"""Stats Configuration
@ -28,30 +40,29 @@ class StatsConfig(Config):
def read_config(self, config, **kwargs):
self.stats_enabled = True
self.stats_bucket_size = 86400 * 1000
self.stats_retention = sys.maxsize
stats_config = config.get("stats", None)
if stats_config:
self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
self.stats_bucket_size = self.parse_duration(
stats_config.get("bucket_size", "1d")
)
self.stats_retention = self.parse_duration(
stats_config.get("retention", "%ds" % (sys.maxsize,))
)
if not self.stats_enabled:
logger.warning(ROOM_STATS_DISABLED_WARN)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Local statistics collection. Used in populating the room directory.
# Settings for local room and user statistics collection. See
# docs/room_and_user_statistics.md.
#
# 'bucket_size' controls how large each statistics timeslice is. It can
# be defined in a human readable short form -- e.g. "1d", "1y".
#
# 'retention' controls how long historical statistics will be kept for.
# It can be defined in a human readable short form -- e.g. "1d", "1y".
#
#
#stats:
# enabled: true
# bucket_size: 1d
# retention: 1y
stats:
# Uncomment the following to disable room and user statistics. Note that doing
# so may cause certain features (such as the room directory) not to work
# correctly.
#
#enabled: false
# The size of each timeslice in the room_stats_historical and
# user_stats_historical tables, as a time period. Defaults to "1d".
#
#bucket_size: 1h
"""

View file

@ -15,6 +15,7 @@
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.rest.media.v1._base import FileInfo
@ -27,6 +28,8 @@ if TYPE_CHECKING:
import synapse.events
import synapse.server
logger = logging.getLogger(__name__)
class SpamChecker:
def __init__(self, hs: "synapse.server.HomeServer"):
@ -190,6 +193,7 @@ class SpamChecker:
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str] = None,
) -> RegistrationBehaviour:
"""Checks if we should allow the given registration request.
@ -198,6 +202,9 @@ class SpamChecker:
username: The request user name, if any
request_info: List of tuples of user agent and IP that
were used during the registration process.
auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
"cas". If any. Note this does not include users registered
via a password provider.
Returns:
Enum for how the request should be handled
@ -208,9 +215,25 @@ class SpamChecker:
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
behaviour = await maybe_awaitable(
checker(email_threepid, username, request_info)
)
# Provide auth_provider_id if the function supports it
checker_args = inspect.signature(checker)
if len(checker_args.parameters) == 4:
d = checker(
email_threepid,
username,
request_info,
auth_provider_id,
)
elif len(checker_args.parameters) == 3:
d = checker(email_threepid, username, request_info)
else:
logger.error(
"Invalid signature for %s.check_registration_for_spam. Denying registration",
spam_checker.__module__,
)
return RegistrationBehaviour.DENY
behaviour = await maybe_awaitable(d)
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour

View file

@ -22,6 +22,7 @@ from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
@ -90,16 +91,15 @@ pdu_process_time = Histogram(
"Time taken to process an event",
)
last_pdu_age_metric = Gauge(
"synapse_federation_last_received_pdu_age",
"The age (in seconds) of the last PDU successfully received from the given domain",
last_pdu_ts_metric = Gauge(
"synapse_federation_last_received_pdu_time",
"The timestamp of the last PDU which was successfully received from the given domain",
labelnames=("server_name",),
)
class FederationServer(FederationBase):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.auth = hs.get_auth()
@ -112,14 +112,15 @@ class FederationServer(FederationBase):
# with FederationHandlerRegistry.
hs.get_directory_handler()
self._federation_ratelimiter = hs.get_federation_ratelimiter()
self._server_linearizer = Linearizer("fed_server")
self._transaction_linearizer = Linearizer("fed_txn_handler")
# origins that we are currently processing a transaction from.
# a dict from origin to txn id.
self._active_transactions = {} # type: Dict[str, str]
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
hs, "fed_txn_handler", timeout_ms=30000
hs.get_clock(), "fed_txn_handler", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self.transaction_actions = TransactionActions(self.store)
@ -129,10 +130,10 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(
hs, "state_resp", timeout_ms=30000
hs.get_clock(), "state_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._state_ids_resp_cache = ResponseCache(
hs, "state_ids_resp", timeout_ms=30000
hs.get_clock(), "state_ids_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._federation_metrics_domains = (
@ -169,6 +170,33 @@ class FederationServer(FederationBase):
logger.debug("[%s] Got transaction", transaction_id)
# Reject malformed transactions early: reject if too many PDUs/EDUs
if len(transaction.pdus) > 50 or ( # type: ignore
hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
):
logger.info("Transaction PDU or EDU count too large. Returning 400")
return 400, {}
# we only process one transaction from each origin at a time. We need to do
# this check here, rather than in _on_incoming_transaction_inner so that we
# don't cache the rejection in _transaction_resp_cache (so that if the txn
# arrives again later, we can process it).
current_transaction = self._active_transactions.get(origin)
if current_transaction and current_transaction != transaction_id:
logger.warning(
"Received another txn %s from %s while still processing %s",
transaction_id,
origin,
current_transaction,
)
return 429, {
"errcode": Codes.UNKNOWN,
"error": "Too many concurrent transactions",
}
# CRITICAL SECTION: we must now not await until we populate _active_transactions
# in _on_incoming_transaction_inner.
# We wrap in a ResponseCache so that we de-duplicate retried
# transactions.
return await self._transaction_resp_cache.wrap(
@ -182,26 +210,18 @@ class FederationServer(FederationBase):
async def _on_incoming_transaction_inner(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
# Use a linearizer to ensure that transactions from a remote are
# processed in order.
with await self._transaction_linearizer.queue(origin):
# We rate limit here *after* we've queued up the incoming requests,
# so that we don't fill up the ratelimiter with blocked requests.
#
# This is important as the ratelimiter allows N concurrent requests
# at a time, and only starts ratelimiting if there are more requests
# than that being processed at a time. If we queued up requests in
# the linearizer/response cache *after* the ratelimiting then those
# queued up requests would count as part of the allowed limit of N
# concurrent requests.
with self._federation_ratelimiter.ratelimit(origin) as d:
await d
# CRITICAL SECTION: the first thing we must do (before awaiting) is
# add an entry to _active_transactions.
assert origin not in self._active_transactions
self._active_transactions[origin] = transaction.transaction_id # type: ignore
result = await self._handle_incoming_transaction(
origin, transaction, request_time
)
return result
try:
result = await self._handle_incoming_transaction(
origin, transaction, request_time
)
return result
finally:
del self._active_transactions[origin]
async def _handle_incoming_transaction(
self, origin: str, transaction: Transaction, request_time: int
@ -227,19 +247,6 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
# Reject if PDU count > 50 or EDU count > 100
if len(transaction.pdus) > 50 or ( # type: ignore
hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
):
logger.info("Transaction PDU or EDU count too large. Returning 400")
response = {}
await self.transaction_actions.set_response(
origin, transaction, 400, response
)
return 400, response
# We process PDUs and EDUs in parallel. This is important as we don't
# want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs.
@ -335,42 +342,48 @@ class FederationServer(FederationBase):
# impose a limit to avoid going too crazy with ram/cpu.
async def process_pdus_for_room(room_id: str):
logger.debug("Processing PDUs for %s", room_id)
try:
await self.check_server_matches_acl(origin_host, room_id)
except AuthError as e:
logger.warning("Ignoring PDUs for room %s from banned server", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
pdu_results[event_id] = e.error_dict()
return
with nested_logging_context(room_id):
logger.debug("Processing PDUs for %s", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
with pdu_process_time.time():
with nested_logging_context(event_id):
try:
await self._handle_received_pdu(origin, pdu)
pdu_results[event_id] = {}
except FederationError as e:
logger.warning("Error handling PDU %s: %s", event_id, e)
pdu_results[event_id] = {"error": str(e)}
except Exception as e:
f = failure.Failure()
pdu_results[event_id] = {"error": str(e)}
logger.error(
"Failed to handle PDU %s",
event_id,
exc_info=(f.type, f.value, f.getTracebackObject()),
)
try:
await self.check_server_matches_acl(origin_host, room_id)
except AuthError as e:
logger.warning(
"Ignoring PDUs for room %s from banned server", room_id
)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
pdu_results[event_id] = e.error_dict()
return
for pdu in pdus_by_room[room_id]:
pdu_results[pdu.event_id] = await process_pdu(pdu)
async def process_pdu(pdu: EventBase) -> JsonDict:
event_id = pdu.event_id
with pdu_process_time.time():
with nested_logging_context(event_id):
try:
await self._handle_received_pdu(origin, pdu)
return {}
except FederationError as e:
logger.warning("Error handling PDU %s: %s", event_id, e)
return {"error": str(e)}
except Exception as e:
f = failure.Failure()
logger.error(
"Failed to handle PDU %s",
event_id,
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
return {"error": str(e)}
await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
if newest_pdu_ts and origin in self._federation_metrics_domains:
newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)
return pdu_results
@ -448,18 +461,22 @@ class FederationServer(FederationBase):
async def _on_state_ids_request_compute(self, room_id, event_id):
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
self, room_id: str, event_id: str
) -> Dict[str, list]:
if event_id:
pdus = await self.handler.get_state_for_pdu(room_id, event_id)
pdus = await self.handler.get_state_for_pdu(
room_id, event_id
) # type: Iterable[EventBase]
else:
pdus = (await self.state.get_current_state(room_id)).values()
auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
auth_chain = await self.store.get_auth_chain(
room_id, [pdu.event_id for pdu in pdus]
)
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
@ -863,7 +880,9 @@ class FederationHandlerRegistry:
self.edu_handlers = (
{}
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
self.query_handlers = (
{}
) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
# Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new
@ -897,7 +916,7 @@ class FederationHandlerRegistry:
self.edu_handlers[edu_type] = handler
def register_query_handler(
self, query_type: str, handler: Callable[[dict], defer.Deferred]
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
@ -970,7 +989,7 @@ class FederationHandlerRegistry:
# Oh well, let's just log and move on.
logger.warning("No handler registered for EDU type %s", edu_type)
async def on_query(self, query_type: str, args: dict):
async def on_query(self, query_type: str, args: dict) -> JsonDict:
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)

View file

@ -17,6 +17,7 @@ import datetime
import logging
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
import attr
from prometheus_client import Counter
from synapse.api.errors import (
@ -93,6 +94,10 @@ class PerDestinationQueue:
self._destination = destination
self.transmission_loop_running = False
# Flag to signal to any running transmission loop that there is new data
# queued up to be sent.
self._new_data_to_send = False
# True whilst we are sending events that the remote homeserver missed
# because it was unreachable. We start in this state so we can perform
# catch-up at startup.
@ -108,7 +113,7 @@ class PerDestinationQueue:
# destination (we are the only updater so this is safe)
self._last_successful_stream_ordering = None # type: Optional[int]
# a list of pending PDUs
# a queue of pending PDUs
self._pending_pdus = [] # type: List[EventBase]
# XXX this is never actually used: see
@ -208,6 +213,10 @@ class PerDestinationQueue:
transaction in the background.
"""
# Mark that we (may) have new things to send, so that any running
# transmission loop will recheck whether there is stuff to send.
self._new_data_to_send = True
if self.transmission_loop_running:
# XXX: this can get stuck on by a never-ending
# request at which point pending_pdus just keeps growing.
@ -250,125 +259,41 @@ class PerDestinationQueue:
pending_pdus = []
while True:
# We have to keep 2 free slots for presence and rr_edus
limit = MAX_EDUS_PER_TRANSACTION - 2
self._new_data_to_send = False
device_update_edus, dev_list_id = await self._get_device_update_edus(
limit
)
limit -= len(device_update_edus)
(
to_device_edus,
device_stream_id,
) = await self._get_to_device_message_edus(limit)
pending_edus = device_update_edus + to_device_edus
# BEGIN CRITICAL SECTION
#
# In order to avoid a race condition, we need to make sure that
# the following code (from popping the queues up to the point
# where we decide if we actually have any pending messages) is
# atomic - otherwise new PDUs or EDUs might arrive in the
# meantime, but not get sent because we hold the
# transmission_loop_running flag.
pending_pdus = self._pending_pdus
# We can only include at most 50 PDUs per transactions
pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
pending_edus.extend(self._get_rr_edus(force_flush=False))
pending_presence = self._pending_presence
self._pending_presence = {}
if pending_presence:
pending_edus.append(
Edu(
origin=self._server_name,
destination=self._destination,
edu_type="m.presence",
content={
"push": [
format_user_presence_state(
presence, self._clock.time_msec()
)
for presence in pending_presence.values()
]
},
)
)
pending_edus.extend(
self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
)
while (
len(pending_edus) < MAX_EDUS_PER_TRANSACTION
and self._pending_edus_keyed
async with _TransactionQueueManager(self) as (
pending_pdus,
pending_edus,
):
_, val = self._pending_edus_keyed.popitem()
pending_edus.append(val)
if not pending_pdus and not pending_edus:
logger.debug("TX [%s] Nothing to send", self._destination)
if pending_pdus:
logger.debug(
"TX [%s] len(pending_pdus_by_dest[dest]) = %d",
self._destination,
len(pending_pdus),
# If we've gotten told about new things to send during
# checking for things to send, we try looking again.
# Otherwise new PDUs or EDUs might arrive in the meantime,
# but not get sent because we hold the
# `transmission_loop_running` flag.
if self._new_data_to_send:
continue
else:
return
if pending_pdus:
logger.debug(
"TX [%s] len(pending_pdus_by_dest[dest]) = %d",
self._destination,
len(pending_pdus),
)
await self._transaction_manager.send_new_transaction(
self._destination, pending_pdus, pending_edus
)
if not pending_pdus and not pending_edus:
logger.debug("TX [%s] Nothing to send", self._destination)
self._last_device_stream_id = device_stream_id
return
# if we've decided to send a transaction anyway, and we have room, we
# may as well send any pending RRs
if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
pending_edus.extend(self._get_rr_edus(force_flush=True))
# END CRITICAL SECTION
success = await self._transaction_manager.send_new_transaction(
self._destination, pending_pdus, pending_edus
)
if success:
sent_transactions_counter.inc()
sent_edus_counter.inc(len(pending_edus))
for edu in pending_edus:
sent_edus_by_type.labels(edu.edu_type).inc()
# Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages
if to_device_edus:
await self._store.delete_device_msgs_for_remote(
self._destination, device_stream_id
)
# also mark the device updates as sent
if device_update_edus:
logger.info(
"Marking as sent %r %r", self._destination, dev_list_id
)
await self._store.mark_as_sent_devices_by_remote(
self._destination, dev_list_id
)
self._last_device_stream_id = device_stream_id
self._last_device_list_stream_id = dev_list_id
if pending_pdus:
# we sent some PDUs and it was successful, so update our
# last_successful_stream_ordering in the destinations table.
final_pdu = pending_pdus[-1]
last_successful_stream_ordering = (
final_pdu.internal_metadata.stream_ordering
)
assert last_successful_stream_ordering
await self._store.set_destination_last_successful_stream_ordering(
self._destination, last_successful_stream_ordering
)
else:
break
except NotRetryingDestination as e:
logger.debug(
"TX [%s] not ready for retry yet (next retry at %s) - "
@ -401,7 +326,7 @@ class PerDestinationQueue:
self._pending_presence = {}
self._pending_rrs = {}
self._start_catching_up()
self._start_catching_up()
except FederationDeniedError as e:
logger.info(e)
except HttpResponseException as e:
@ -412,7 +337,6 @@ class PerDestinationQueue:
e,
)
self._start_catching_up()
except RequestSendFailed as e:
logger.warning(
"TX [%s] Failed to send transaction: %s", self._destination, e
@ -422,16 +346,12 @@ class PerDestinationQueue:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
self._start_catching_up()
except Exception:
logger.exception("TX [%s] Failed to send transaction", self._destination)
for p in pending_pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
self._start_catching_up()
finally:
# We want to be *very* sure we clear this after we stop processing
self.transmission_loop_running = False
@ -499,13 +419,10 @@ class PerDestinationQueue:
rooms = [p.room_id for p in catchup_pdus]
logger.info("Catching up rooms to %s: %r", self._destination, rooms)
success = await self._transaction_manager.send_new_transaction(
await self._transaction_manager.send_new_transaction(
self._destination, catchup_pdus, []
)
if not success:
return
sent_transactions_counter.inc()
final_pdu = catchup_pdus[-1]
self._last_successful_stream_ordering = cast(
@ -584,3 +501,135 @@ class PerDestinationQueue:
"""
self._catching_up = True
self._pending_pdus = []
@attr.s(slots=True)
class _TransactionQueueManager:
"""A helper async context manager for pulling stuff off the queues and
tracking what was last successfully sent, etc.
"""
queue = attr.ib(type=PerDestinationQueue)
_device_stream_id = attr.ib(type=Optional[int], default=None)
_device_list_id = attr.ib(type=Optional[int], default=None)
_last_stream_ordering = attr.ib(type=Optional[int], default=None)
_pdus = attr.ib(type=List[EventBase], factory=list)
async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
# First we calculate the EDUs we want to send, if any.
# We start by fetching device related EDUs, i.e device updates and to
# device messages. We have to keep 2 free slots for presence and rr_edus.
limit = MAX_EDUS_PER_TRANSACTION - 2
device_update_edus, dev_list_id = await self.queue._get_device_update_edus(
limit
)
if device_update_edus:
self._device_list_id = dev_list_id
else:
self.queue._last_device_list_stream_id = dev_list_id
limit -= len(device_update_edus)
(
to_device_edus,
device_stream_id,
) = await self.queue._get_to_device_message_edus(limit)
if to_device_edus:
self._device_stream_id = device_stream_id
else:
self.queue._last_device_stream_id = device_stream_id
pending_edus = device_update_edus + to_device_edus
# Now add the read receipt EDU.
pending_edus.extend(self.queue._get_rr_edus(force_flush=False))
# And presence EDU.
if self.queue._pending_presence:
pending_edus.append(
Edu(
origin=self.queue._server_name,
destination=self.queue._destination,
edu_type="m.presence",
content={
"push": [
format_user_presence_state(
presence, self.queue._clock.time_msec()
)
for presence in self.queue._pending_presence.values()
]
},
)
)
self.queue._pending_presence = {}
# Finally add any other types of EDUs if there is room.
pending_edus.extend(
self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
)
while (
len(pending_edus) < MAX_EDUS_PER_TRANSACTION
and self.queue._pending_edus_keyed
):
_, val = self.queue._pending_edus_keyed.popitem()
pending_edus.append(val)
# Now we look for any PDUs to send, by getting up to 50 PDUs from the
# queue
self._pdus = self.queue._pending_pdus[:50]
if not self._pdus and not pending_edus:
return [], []
# if we've decided to send a transaction anyway, and we have room, we
# may as well send any pending RRs
if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
pending_edus.extend(self.queue._get_rr_edus(force_flush=True))
if self._pdus:
self._last_stream_ordering = self._pdus[
-1
].internal_metadata.stream_ordering
assert self._last_stream_ordering
return self._pdus, pending_edus
async def __aexit__(self, exc_type, exc, tb):
if exc_type is not None:
# Failed to send transaction, so we bail out.
return
# Successfully sent transactions, so we remove pending PDUs from the queue
if self._pdus:
self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :]
# Succeeded to send the transaction so we record where we have sent up
# to in the various streams
if self._device_stream_id:
await self.queue._store.delete_device_msgs_for_remote(
self.queue._destination, self._device_stream_id
)
self.queue._last_device_stream_id = self._device_stream_id
# also mark the device updates as sent
if self._device_list_id:
logger.info(
"Marking as sent %r %r", self.queue._destination, self._device_list_id
)
await self.queue._store.mark_as_sent_devices_by_remote(
self.queue._destination, self._device_list_id
)
self.queue._last_device_list_stream_id = self._device_list_id
if self._last_stream_ordering:
# we sent some PDUs and it was successful, so update our
# last_successful_stream_ordering in the destinations table.
await self.queue._store.set_destination_last_successful_stream_ordering(
self.queue._destination, self._last_stream_ordering
)

View file

@ -36,9 +36,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
last_pdu_age_metric = Gauge(
"synapse_federation_last_sent_pdu_age",
"The age (in seconds) of the last PDU successfully sent to the given domain",
last_pdu_ts_metric = Gauge(
"synapse_federation_last_sent_pdu_time",
"The timestamp of the last PDU which was successfully sent to the given domain",
labelnames=("server_name",),
)
@ -69,15 +69,12 @@ class TransactionManager:
destination: str,
pdus: List[EventBase],
edus: List[Edu],
) -> bool:
) -> None:
"""
Args:
destination: The destination to send to (e.g. 'example.org')
pdus: In-order list of PDUs to send
edus: List of EDUs to send
Returns:
True iff the transaction was successful
"""
# Make a transaction-sending opentracing span. This span follows on from
@ -96,8 +93,6 @@ class TransactionManager:
edu.strip_context()
with start_active_span_follows_from("send_transaction", span_contexts):
success = True
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
@ -152,45 +147,29 @@ class TransactionManager:
response = await self._transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200
except HttpResponseException as e:
code = e.code
response = e.response
if e.code in (401, 404, 429) or 500 <= e.code:
logger.info(
"TX [%s] {%s} got %d response", destination, txn_id, code
)
raise e
set_tag(tags.ERROR, True)
logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
raise
if code == 200:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warning(
"TX [%s] {%s} Remote returned error for %s: %s",
destination,
txn_id,
e_id,
r,
)
else:
for p in pdus:
logger.info("TX [%s] {%s} got 200 response", destination, txn_id)
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warning(
"TX [%s] {%s} Failed to send event %s",
"TX [%s] {%s} Remote returned error for %s: %s",
destination,
txn_id,
p.event_id,
e_id,
r,
)
success = False
if success and pdus and destination in self._federation_metrics_domains:
if pdus and destination in self._federation_metrics_domains:
last_pdu = pdus[-1]
last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
last_pdu_age_metric.labels(server_name=destination).set(
last_pdu_age / 1000
last_pdu_ts_metric.labels(server_name=destination).set(
last_pdu.origin_server_ts / 1000
)
set_tag(tags.ERROR, not success)
return success

View file

@ -73,7 +73,9 @@ class AcmeHandler:
"Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
)
try:
self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
self.reactor.listenTCP(
self.hs.config.acme_port, srv, backlog=50, interface=host
)
except twisted.internet.error.CannotListenError as e:
check_bind_error(e, host, bind_addresses)

View file

@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
extra_attributes = attr.ib(type=JsonDict)
@attr.s(slots=True, frozen=True)
class LoginTokenAttributes:
"""Data we store in a short-term login token"""
user_id = attr.ib(type=str)
# the SSO Identity Provider that the user authenticated with, to get this token
auth_provider_id = attr.ib(type=str)
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@ -326,7 +337,8 @@ class AuthHandler(BaseHandler):
user is too high to proceed
"""
if not requester.access_token_id:
raise ValueError("Cannot validate a user without an access token")
if self._ui_auth_session_timeout:
last_validated = await self.store.get_access_token_last_validated(
requester.access_token_id
@ -1164,18 +1176,16 @@ class AuthHandler(BaseHandler):
return None
return user_id
async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth()
user_id = None
async def validate_short_term_login_token(
self, login_token: str
) -> LoginTokenAttributes:
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
user_id = auth_api.get_user_id_from_macaroon(macaroon)
auth_api.validate_macaroon(macaroon, "login", user_id)
res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
await self.auth.check_auth_blocking(user_id)
return user_id
await self.auth.check_auth_blocking(res.user_id)
return res
async def delete_access_token(self, access_token: str):
"""Invalidate a single access token
@ -1204,7 +1214,7 @@ class AuthHandler(BaseHandler):
async def delete_access_tokens_for_user(
self,
user_id: str,
except_token_id: Optional[str] = None,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
):
"""Invalidate access tokens belonging to a user
@ -1397,6 +1407,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login(
self,
registered_user_id: str,
auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@ -1406,6 +1417,9 @@ class AuthHandler(BaseHandler):
Args:
registered_user_id: The registered user ID to complete SSO login for.
auth_provider_id: The id of the SSO Identity provider that was used for
login. This will be stored in the login token for future tracking in
prometheus metrics.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
@ -1427,6 +1441,7 @@ class AuthHandler(BaseHandler):
self._complete_sso_login(
registered_user_id,
auth_provider_id,
request,
client_redirect_url,
extra_attributes,
@ -1437,6 +1452,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login(
self,
registered_user_id: str,
auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@ -1463,7 +1479,7 @@ class AuthHandler(BaseHandler):
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
registered_user_id, auth_provider_id=auth_provider_id
)
# Append the login token to the original redirect URL (i.e. with its query
@ -1569,15 +1585,48 @@ class MacaroonGenerator:
return macaroon.serialize()
def generate_short_term_login_token(
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
self,
user_id: str,
auth_provider_id: str,
duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
return macaroon.serialize()
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
"""Verify a short-term-login macaroon
Checks that the given token is a valid, unexpired short-term-login token
minted by this server.
Args:
token: the login token to verify
Returns:
the user_id that this token is valid for
Raises:
MacaroonVerificationFailedException if the verification failed
"""
macaroon = pymacaroons.Macaroon.deserialize(token)
user_id = get_value_from_macaroon(macaroon, "user_id")
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = login")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
satisfy_expiry(v, self.hs.get_clock().time_msec)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")

View file

@ -83,6 +83,7 @@ class CasHandler:
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
self.unstable_idp_brand = None
self._sso_handler = hs.get_sso_handler()

View file

@ -201,7 +201,7 @@ class FederationHandler(BaseHandler):
or pdu.internal_metadata.is_outlier()
)
if already_seen:
logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
logger.debug("Already seen pdu")
return
# do some initial sanity-checking of the event. In particular, make
@ -210,18 +210,14 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
logger.warning(
"[%s %s] Received event failed sanity checks", room_id, event_id
)
logger.warning("Received event failed sanity checks")
raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if room_id in self.room_queues:
logger.info(
"[%s %s] Queuing PDU from %s for now: join in progress",
room_id,
event_id,
"Queuing PDU from %s for now: join in progress",
origin,
)
self.room_queues[room_id].append((pdu, origin))
@ -236,9 +232,7 @@ class FederationHandler(BaseHandler):
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
room_id,
event_id,
"Ignoring PDU from %s as we're not in the room",
origin,
)
return None
@ -250,7 +244,7 @@ class FederationHandler(BaseHandler):
# We only backfill backwards to the min depth.
min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
logger.debug("min_depth: %d", min_depth)
prevs = set(pdu.prev_event_ids())
seen = await self.store.have_events_in_timeline(prevs)
@ -267,17 +261,13 @@ class FederationHandler(BaseHandler):
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
"[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
room_id,
event_id,
"Acquiring room lock to fetch %d missing prev_events: %s",
len(missing_prevs),
shortstr(missing_prevs),
)
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,
event_id,
"Acquired room lock to fetch %d missing prev_events",
len(missing_prevs),
)
@ -297,9 +287,7 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
logger.info(
"[%s %s] Found all missing prev_events",
room_id,
event_id,
"Found all missing prev_events",
)
if prevs - seen:
@ -329,9 +317,7 @@ class FederationHandler(BaseHandler):
if sent_to_us_directly:
logger.warning(
"[%s %s] Rejecting: failed to fetch %d prev events: %s",
room_id,
event_id,
"Rejecting: failed to fetch %d prev events: %s",
len(prevs - seen),
shortstr(prevs - seen),
)
@ -367,17 +353,16 @@ class FederationHandler(BaseHandler):
# Ask the remote server for the states we don't
# know about
for p in prevs - seen:
logger.info(
"Requesting state at missing prev_event %s",
event_id,
)
logger.info("Requesting state after missing prev_event %s", p)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
(remote_state, _,) = await self._get_state_for_room(
origin, room_id, p, include_event_in_state=True
remote_state = (
await self._get_state_after_missing_prev_event(
origin, room_id, p
)
)
remote_state_map = {
@ -414,10 +399,7 @@ class FederationHandler(BaseHandler):
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"[%s %s] Error attempting to resolve state at missing "
"prev_events",
room_id,
event_id,
"Error attempting to resolve state at missing " "prev_events",
exc_info=True,
)
raise FederationError(
@ -454,9 +436,7 @@ class FederationHandler(BaseHandler):
latest |= seen
logger.info(
"[%s %s]: Requesting missing events between %s and %s",
room_id,
event_id,
"Requesting missing events between %s and %s",
shortstr(latest),
event_id,
)
@ -523,15 +503,11 @@ class FederationHandler(BaseHandler):
# We failed to get the missing events, but since we need to handle
# the case of `get_missing_events` not returning the necessary
# events anyway, it is safe to simply log the error and continue.
logger.warning(
"[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
)
logger.warning("Failed to get prev_events: %s", e)
return
logger.info(
"[%s %s]: Got %d prev_events: %s",
room_id,
event_id,
"Got %d prev_events: %s",
len(missing_events),
shortstr(missing_events),
)
@ -542,9 +518,7 @@ class FederationHandler(BaseHandler):
for ev in missing_events:
logger.info(
"[%s %s] Handling received prev_event %s",
room_id,
event_id,
"Handling received prev_event %s",
ev.event_id,
)
with nested_logging_context(ev.event_id):
@ -553,9 +527,7 @@ class FederationHandler(BaseHandler):
except FederationError as e:
if e.code == 403:
logger.warning(
"[%s %s] Received prev_event %s failed history check.",
room_id,
event_id,
"Received prev_event %s failed history check.",
ev.event_id,
)
else:
@ -566,7 +538,6 @@ class FederationHandler(BaseHandler):
destination: str,
room_id: str,
event_id: str,
include_event_in_state: bool = False,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Requests all of the room state at a given event from a remote homeserver.
@ -574,11 +545,9 @@ class FederationHandler(BaseHandler):
destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
include_event_in_state: if true, the event itself will be included in the
returned state event list.
Returns:
A list of events in the state, possibly including the event itself, and
A list of events in the state, not including the event itself, and
a list of events in the auth chain for the given event.
"""
(
@ -590,9 +559,6 @@ class FederationHandler(BaseHandler):
desired_events = set(state_event_ids + auth_event_ids)
if include_event_in_state:
desired_events.add(event_id)
event_map = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
@ -609,13 +575,6 @@ class FederationHandler(BaseHandler):
event_map[e_id] for e_id in state_event_ids if e_id in event_map
]
if include_event_in_state:
remote_event = event_map.get(event_id)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
if remote_event.is_state() and remote_event.rejected_reason is None:
remote_state.append(remote_event)
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
@ -689,6 +648,131 @@ class FederationHandler(BaseHandler):
return fetched_events
async def _get_state_after_missing_prev_event(
self,
destination: str,
room_id: str,
event_id: str,
) -> List[EventBase]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
Returns:
A list of events in the state, including the event itself
"""
# TODO: This function is basically the same as _get_state_for_room. Can
# we make backfill() use it, rather than having two code paths? I think the
# only difference is that backfill() persists the prev events separately.
(
state_event_ids,
auth_event_ids,
) = await self.federation_client.get_room_state_ids(
destination, room_id, event_id=event_id
)
logger.debug(
"state_ids returned %i state events, %i auth events",
len(state_event_ids),
len(auth_event_ids),
)
# start by just trying to fetch the events from the store
desired_events = set(state_event_ids)
desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events))
fetched_events = await self.store.get_events(
desired_events, allow_rejected=True
)
missing_desired_events = desired_events - fetched_events.keys()
logger.debug(
"We are missing %i events (got %i)",
len(missing_desired_events),
len(fetched_events),
)
# We probably won't need most of the auth events, so let's just check which
# we have for now, rather than thrashing the event cache with them all
# unnecessarily.
# TODO: we probably won't actually need all of the auth events, since we
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
missing_auth_events.difference_update(
await self.store.have_seen_events(missing_auth_events)
)
logger.debug("We are also missing %i auth events", len(missing_auth_events))
missing_events = missing_desired_events | missing_auth_events
logger.debug("Fetching %i events from remote", len(missing_events))
await self._get_events_and_persist(
destination=destination, room_id=room_id, events=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
(await self.store.get_events(missing_desired_events, allow_rejected=True))
)
# check for events which were in the wrong room.
#
# this can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
bad_events = [
(event_id, event.room_id)
for event_id, event in fetched_events.items()
if event.room_id != room_id
]
for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
bad_event_id,
bad_room_id,
room_id,
)
del fetched_events[bad_event_id]
# if we couldn't get the prev event in question, that's a problem.
remote_event = fetched_events.get(event_id)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
# missing state at that event is a warning, not a blocker
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
failed_to_fetch = desired_events - fetched_events.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
event_id,
failed_to_fetch,
)
remote_state = [
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
]
if remote_event.is_state() and remote_event.rejected_reason is None:
remote_state.append(remote_event)
return remote_state
async def _process_received_pdu(
self,
origin: str,
@ -707,10 +791,7 @@ class FederationHandler(BaseHandler):
(ie, we are missing one or more prev_events), the resolved state at the
event
"""
room_id = event.room_id
event_id = event.event_id
logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
logger.debug("Processing event: %s", event)
try:
await self._handle_new_event(origin, event, state=state)
@ -871,7 +952,6 @@ class FederationHandler(BaseHandler):
destination=dest,
room_id=room_id,
event_id=e_id,
include_event_in_state=False,
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
@ -1317,7 +1397,7 @@ class FederationHandler(BaseHandler):
async def on_event_auth(self, event_id: str) -> List[EventBase]:
event = await self.store.get_event(event_id)
auth = await self.store.get_auth_chain(
list(event.auth_event_ids()), include_given=True
event.room_id, list(event.auth_event_ids()), include_given=True
)
return list(auth)
@ -1580,7 +1660,7 @@ class FederationHandler(BaseHandler):
prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
auth_chain = await self.store.get_auth_chain(state_ids)
auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
state = await self.store.get_events(list(prev_state_ids.values()))
@ -2219,7 +2299,7 @@ class FederationHandler(BaseHandler):
# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
list(event.auth_event_ids()), include_given=True
room_id, list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell

View file

@ -48,7 +48,7 @@ class InitialSyncHandler(BaseHandler):
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = ResponseCache(
hs, "initial_sync_cache"
hs.get_clock(), "initial_sync_cache"
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,13 +15,13 @@
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode
import attr
import pymacaroons
from authlib.common.security import generate_token
from authlib.jose import JsonWebToken
from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
@ -28,20 +29,26 @@ from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
MacaroonDeserializationException,
MacaroonInitException,
MacaroonInvalidSignatureException,
)
from typing_extensions import TypedDict
from twisted.web.client import readBody
from twisted.web.http_headers import Headers
from synapse.config import ConfigError
from synapse.config.oidc_config import OidcProviderConfig
from synapse.config.oidc_config import (
OidcProviderClientSecretJwtKey,
OidcProviderConfig,
)
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -211,7 +218,7 @@ class OidcHandler:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
except (MacaroonDeserializationException, ValueError) as e:
except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e:
logger.exception("Invalid session for OIDC callback")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
@ -275,9 +282,21 @@ class OidcProvider:
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
client_secret = None # type: Union[None, str, JwtClientSecret]
if provider.client_secret:
client_secret = provider.client_secret
elif provider.client_secret_jwt_key:
client_secret = JwtClientSecret(
provider.client_secret_jwt_key,
provider.client_id,
provider.issuer,
hs.get_clock(),
)
self._client_auth = ClientAuth(
provider.client_id,
provider.client_secret,
client_secret,
provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = provider.client_auth_method
@ -312,6 +331,9 @@ class OidcProvider:
# optional brand identifier for this auth provider
self.idp_brand = provider.idp_brand
# Optional brand identifier for the unstable API (see MSC2858).
self.unstable_idp_brand = provider.unstable_idp_brand
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
@ -521,7 +543,7 @@ class OidcProvider:
"""
metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint")
headers = {
raw_headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent,
"Accept": "application/json",
@ -535,10 +557,10 @@ class OidcProvider:
body = urlencode(args, True)
# Fill the body/headers with credentials
uri, headers, body = self._client_auth.prepare(
method="POST", uri=token_endpoint, headers=headers, body=body
uri, raw_headers, body = self._client_auth.prepare(
method="POST", uri=token_endpoint, headers=raw_headers, body=body
)
headers = {k: [v] for (k, v) in headers.items()}
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
@ -745,7 +767,7 @@ class OidcProvider:
idp_id=self.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
ui_auth_session_id=ui_auth_session_id or "",
),
)
@ -976,6 +998,81 @@ class OidcProvider:
return str(remote_user_id)
# number of seconds a newly-generated client secret should be valid for
CLIENT_SECRET_VALIDITY_SECONDS = 3600
# minimum remaining validity on a client secret before we should generate a new one
CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
class JwtClientSecret:
"""A class which generates a new client secret on demand, based on a JWK
This implementation is designed to comply with the requirements for Apple Sign in:
https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
but it's worth noting that we still put the generated secret in the "client_secret"
field (or rather, whereever client_auth_method puts it) rather than in a
client_assertion field in the body as that RFC seems to require.
"""
def __init__(
self,
key: OidcProviderClientSecretJwtKey,
oauth_client_id: str,
oauth_issuer: str,
clock: Clock,
):
self._key = key
self._oauth_client_id = oauth_client_id
self._oauth_issuer = oauth_issuer
self._clock = clock
self._cached_secret = b""
self._cached_secret_replacement_time = 0
def __str__(self):
# if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
# encode_client_secret_basic, which calls "{}".format(secret), which ends up
# here.
return self._get_secret().decode("ascii")
def __bytes__(self):
# if client_auth_method is client_secret_post, then ClientAuth.prepare calls
# encode_client_secret_post, which ends up here.
return self._get_secret()
def _get_secret(self) -> bytes:
now = self._clock.time()
# if we have enough validity on our existing secret, use it
if now < self._cached_secret_replacement_time:
return self._cached_secret
issued_at = int(now)
expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
# we copy the configured header because jwt.encode modifies it.
header = dict(self._key.jwt_header)
# see https://tools.ietf.org/html/rfc7523#section-3
payload = {
"sub": self._oauth_client_id,
"aud": self._oauth_issuer,
"iat": issued_at,
"exp": expires_at,
**self._key.jwt_payload,
}
logger.info(
"Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
)
self._cached_secret = jwt.encode(header, payload, self._key.key)
self._cached_secret_replacement_time = (
expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
)
return self._cached_secret
class OidcSessionTokenGenerator:
"""Methods for generating and checking OIDC Session cookies."""
@ -1020,10 +1117,9 @@ class OidcSessionTokenGenerator:
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
)
if session_data.ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
)
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
)
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@ -1046,7 +1142,7 @@ class OidcSessionTokenGenerator:
The data extracted from the session cookie
Raises:
ValueError if an expected caveat is missing from the macaroon.
KeyError if an expected caveat is missing from the macaroon.
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
@ -1057,26 +1153,16 @@ class OidcSessionTokenGenerator:
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
satisfy_expiry(v, self._clock.time_msec)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the session data from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None
nonce = get_value_from_macaroon(macaroon, "nonce")
idp_id = get_value_from_macaroon(macaroon, "idp_id")
client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
return OidcSessionData(
nonce=nonce,
idp_id=idp_id,
@ -1084,33 +1170,6 @@ class OidcSessionTokenGenerator:
ui_auth_session_id=ui_auth_session_id,
)
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
ValueError: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
return now < expiry
@attr.s(frozen=True, slots=True)
class OidcSessionData:
@ -1125,8 +1184,8 @@ class OidcSessionData:
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
# The session ID of the ongoing UI Auth (None if this is a login)
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
# The session ID of the ongoing UI Auth ("" if this is a login)
ui_auth_session_id = attr.ib(type=str)
UserAttributeDict = TypedDict(

View file

@ -285,7 +285,7 @@ class PaginationHandler:
except Exception:
f = Failure()
logger.error(
"[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
"[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore
)
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:

View file

@ -16,7 +16,9 @@
"""Contains functions for registering clients."""
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from prometheus_client import Counter
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@ -41,6 +43,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
registration_counter = Counter(
"synapse_user_registrations_total",
"Number of new users registered (since restart)",
["guest", "shadow_banned", "auth_provider"],
)
login_counter = Counter(
"synapse_user_logins_total",
"Number of user logins (since restart)",
["guest", "auth_provider"],
)
class RegistrationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@ -67,6 +82,7 @@ class RegistrationHandler(BaseHandler):
)
else:
self.device_handler = hs.get_device_handler()
self._register_device_client = self.register_device_inner
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.session_lifetime
@ -161,6 +177,7 @@ class RegistrationHandler(BaseHandler):
bind_emails: Iterable[str] = [],
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
auth_provider_id: Optional[str] = None,
) -> str:
"""Registers a new client on the server.
@ -186,8 +203,9 @@ class RegistrationHandler(BaseHandler):
admin api, otherwise False.
user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
auth_provider_id: The SSO IdP the user used, if any.
Returns:
The registere user_id.
The registered user_id.
Raises:
SynapseError if there was a problem registering.
"""
@ -197,6 +215,7 @@ class RegistrationHandler(BaseHandler):
threepid,
localpart,
user_agent_ips or [],
auth_provider_id=auth_provider_id,
)
if result == RegistrationBehaviour.DENY:
@ -287,6 +306,12 @@ class RegistrationHandler(BaseHandler):
# if user id is taken, just generate another
fail_count += 1
registration_counter.labels(
guest=make_guest,
shadow_banned=shadow_banned,
auth_provider=(auth_provider_id or ""),
).inc()
if not self.hs.config.user_consent_at_registration:
if not self.hs.config.auto_join_rooms_for_guests and make_guest:
logger.info(
@ -645,6 +670,7 @@ class RegistrationHandler(BaseHandler):
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
@ -655,21 +681,40 @@ class RegistrationHandler(BaseHandler):
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
Returns:
Tuple of device ID and access token
"""
res = await self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
)
if self.hs.config.worker_app:
r = await self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
)
return r["device_id"], r["access_token"]
login_counter.labels(
guest=is_guest,
auth_provider=(auth_provider_id or ""),
).inc()
return res["device_id"], res["access_token"]
async def register_device_inner(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
) -> Dict[str, str]:
"""Helper for register_device
Does the bits that need doing on the main process. Not for use outside this
class and RegisterDeviceReplicationServlet.
"""
assert not self.hs.config.worker_app
valid_until_ms = None
if self.session_lifetime is not None:
if is_guest:
@ -694,7 +739,7 @@ class RegistrationHandler(BaseHandler):
is_appservice_ghost=is_appservice_ghost,
)
return (registered_device_id, access_token)
return {"device_id": registered_device_id, "access_token": access_token}
async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]

View file

@ -121,7 +121,7 @@ class RoomCreationHandler(BaseHandler):
# succession, only process the first attempt and return its result to
# subsequent requests
self._upgrade_response_cache = ResponseCache(
hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
) # type: ResponseCache[Tuple[str, str]]
self._server_notices_mxid = hs.config.server_notices_mxid

View file

@ -44,10 +44,10 @@ class RoomListHandler(BaseHandler):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(
hs, "room_list"
hs.get_clock(), "room_list"
) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
self.remote_response_cache = ResponseCache(
hs, "remote_room_list", timeout_ms=30 * 1000
hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
async def get_local_public_room_list(

View file

@ -81,6 +81,7 @@ class SamlHandler(BaseHandler):
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
self.unstable_idp_brand = None
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]

View file

@ -98,6 +98,11 @@ class SsoIdentityProvider(Protocol):
"""Optional branding identifier"""
return None
@property
def unstable_idp_brand(self) -> Optional[str]:
"""Optional brand identifier for the unstable API (see MSC2858)."""
return None
@abc.abstractmethod
async def handle_redirect_request(
self,
@ -456,6 +461,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
auth_provider_id,
request,
client_redirect_url,
extra_login_attributes,
@ -605,6 +611,7 @@ class SsoHandler:
default_display_name=attributes.display_name,
bind_emails=attributes.emails,
user_agent_ips=[(user_agent, ip_address)],
auth_provider_id=auth_provider_id,
)
await self._store.record_user_external_id(
@ -886,6 +893,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
session.auth_provider_id,
request,
session.client_redirect_url,
session.extra_login_attributes,

View file

@ -244,7 +244,7 @@ class SyncHandler:
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache(
hs, "sync"
hs.get_clock(), "sync"
) # type: ResponseCache[Tuple[Any, ...]]
self.state = hs.get_state_handler()
self.auth = hs.get_auth()

View file

@ -39,12 +39,15 @@ from zope.interface import implementer, provider
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
ITCPTransport,
)
from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@ -56,13 +59,20 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
from twisted.web.iweb import (
UNKNOWN_LENGTH,
IAgent,
IBodyProducer,
IPolicyForHTTPS,
IResponse,
)
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@ -150,16 +160,17 @@ class _IPBlacklistingResolver:
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
r = recv()
addresses = [] # type: List[IAddress]
def _callback() -> None:
r.resolutionBegan(None)
has_bad_ip = False
for i in addresses:
ip_address = IPAddress(i.host)
for address in addresses:
# We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
# should go through this path.
if not isinstance(address, (IPv4Address, IPv6Address)):
continue
ip_address = IPAddress(address.host)
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
@ -174,15 +185,15 @@ class _IPBlacklistingResolver:
# request, but all we can really do from here is claim that there were no
# valid results.
if not has_bad_ip:
for i in addresses:
r.addressResolved(i)
r.resolutionComplete()
for address in addresses:
recv.addressResolved(address)
recv.resolutionComplete()
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
recv.resolutionBegan(resolutionInProgress)
@staticmethod
def addressResolved(address: IAddress) -> None:
@ -196,10 +207,10 @@ class _IPBlacklistingResolver:
EndpointReceiver, hostname, portNumber=portNumber
)
return r
return recv
@implementer(IReactorPluggableNameResolver)
@implementer(ISynapseReactor)
class BlacklistingReactorWrapper:
"""
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
@ -325,7 +336,7 @@ class SimpleHttpClient:
# filters out blacklisted IP addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
)
) # type: ISynapseReactor
else:
self.reactor = hs.get_reactor()
@ -346,7 +357,7 @@ class SimpleHttpClient:
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
use_proxy=use_proxy,
)
) # type: IAgent
if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
@ -752,6 +763,8 @@ class BodyExceededMaxSize(Exception):
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""
transport = None # type: Optional[ITCPTransport]
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
@ -763,18 +776,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()
def dataReceived(self, data: bytes) -> None:
self._maybe_fail()
def connectionLost(self, reason: Failure) -> None:
def connectionLost(self, reason: Failure = connectionDone) -> None:
self._maybe_fail()
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
transport = None # type: Optional[ITCPTransport]
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@ -797,9 +813,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()
def connectionLost(self, reason: Failure) -> None:
def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return
@ -868,6 +885,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
return query_str.encode("utf8")
@implementer(IPolicyForHTTPS)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
"""
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.

View file

@ -35,6 +35,7 @@ from synapse.http.client import BlacklistingAgentWrapper
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import ISynapseReactor
from synapse.util import Clock
logger = logging.getLogger(__name__)
@ -68,7 +69,7 @@ class MatrixFederationAgent:
def __init__(
self,
reactor: IReactorCore,
reactor: ISynapseReactor,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
ip_blacklist: IPSet,

View file

@ -322,7 +322,8 @@ def _cache_period_from_headers(
def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
cache_controls = {}
for hdr in headers.getRawHeaders(b"cache-control", []):
cache_control_headers = headers.getRawHeaders(b"cache-control") or []
for hdr in cache_control_headers:
for directive in hdr.split(b","):
splits = [x.strip() for x in directive.split(b"=", 1)]
k = splits[0].lower()

View file

@ -59,7 +59,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
from synapse.types import JsonDict
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@ -237,14 +237,14 @@ class MatrixFederationHttpClient:
# addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
)
) # type: ISynapseReactor
user_agent = hs.version_string
if hs.config.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
user_agent = user_agent.encode("ascii")
self.agent = MatrixFederationAgent(
federation_agent = MatrixFederationAgent(
self.reactor,
tls_client_options_factory,
user_agent,
@ -254,7 +254,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
self.agent,
federation_agent,
ip_blacklist=hs.config.federation_ip_range_blacklist,
)
@ -534,9 +534,10 @@ class MatrixFederationHttpClient:
response.code, response_phrase, body
)
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
if response.code == 429:
# Retry if the error is a 5xx or a 429 (Too Many
# Requests), otherwise just raise a standard
# `HttpResponseException`
if 500 <= response.code < 600 or response.code == 429:
raise RequestSendFailed(exc, can_retry=True) from exc
else:
raise exc

View file

@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.tcp import Connection
from twisted.python.failure import Failure
logger = logging.getLogger(__name__)
@ -52,7 +53,9 @@ class LogProducer:
format: A callable to format the log record to a string.
"""
transport = attr.ib(type=ITransport)
# This is essentially ITCPTransport, but that is missing certain fields
# (connected and registerProducer) which are part of the implementation.
transport = attr.ib(type=Connection)
_format = attr.ib(type=Callable[[logging.LogRecord], str])
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler):
if self._connection_waiter:
return
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
def fail(failure: Failure) -> None:
# If the Deferred was cancelled (e.g. during shutdown) do not try to
# reconnect (this will cause an infinite loop of errors).
@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler):
self._connect()
def writer(result: Protocol) -> None:
# Force recognising transport as a Connection and not the more
# generic ITransport.
transport = result.transport # type: Connection # type: ignore
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
if self._producer and result.transport is self._producer.transport:
if self._producer and transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler):
# Make a new producer and start it.
self._producer = LogProducer(
buffer=self._buffer,
transport=result.transport,
transport=transport,
format=self.format,
)
result.transport.registerProducer(self._producer, True)
transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
self._connection_waiter.addCallbacks(writer, fail)
deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
deferred.addCallbacks(writer, fail)
self._connection_waiter = deferred
def _handle_pressure(self) -> None:
"""

View file

@ -669,7 +669,7 @@ def preserve_fn(f):
return g
def run_in_background(f, *args, **kwargs):
def run_in_background(f, *args, **kwargs) -> defer.Deferred:
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@ -697,8 +697,10 @@ def run_in_background(f, *args, **kwargs):
if isinstance(res, types.CoroutineType):
res = defer.ensureDeferred(res)
# At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
return res
return defer.succeed(res)
if res.called and not res.paused:
# The function should have maintained the logcontext, so we can

View file

@ -203,11 +203,26 @@ class ModuleApi:
)
def generate_short_term_login_token(
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
self,
user_id: str,
duration_in_ms: int = (2 * 60 * 1000),
auth_provider_id: str = "",
) -> str:
"""Generate a login token suitable for m.login.token authentication"""
"""Generate a login token suitable for m.login.token authentication
Args:
user_id: gives the ID of the user that the token is for
duration_in_ms: the time that the token will be valid for
auth_provider_id: the ID of the SSO IdP that the user used to authenticate
to get this token, if any. This is encoded in the token so that
/login can report stats on number of successful logins by IdP.
"""
return self._hs.get_macaroon_generator().generate_short_term_login_token(
user_id, duration_in_ms
user_id,
auth_provider_id,
duration_in_ms,
)
@defer.inlineCallbacks
@ -276,6 +291,7 @@ class ModuleApi:
"""
self._auth_handler._complete_sso_login(
registered_user_id,
"<unknown>",
request,
client_redirect_url,
)
@ -286,6 +302,7 @@ class ModuleApi:
request: SynapseRequest,
client_redirect_url: str,
new_user: bool = False,
auth_provider_id: str = "<unknown>",
):
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
@ -299,9 +316,15 @@ class ModuleApi:
redirect them directly if whitelisted).
new_user: set to true to use wording for the consent appropriate to a user
who has just registered.
auth_provider_id: the ID of the SSO IdP which was used to log in. This
is used to track counts of sucessful logins by IdP.
"""
await self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url, new_user=new_user
registered_user_id,
auth_provider_id,
request,
client_redirect_url,
new_user=new_user,
)
@defer.inlineCallbacks

View file

@ -16,8 +16,8 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Optional
from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from twisted.internet.interfaces import IDelayedCall
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, ThrottleParams
@ -66,7 +66,7 @@ class EmailPusher(Pusher):
self.store = self.hs.get_datastore()
self.email = pusher_config.pushkey
self.timed_call = None # type: Optional[DelayedCall]
self.timed_call = None # type: Optional[IDelayedCall]
self.throttle_params = {} # type: Dict[str, ThrottleParams]
self._inited = False

View file

@ -18,7 +18,7 @@ import logging
import re
import urllib
from inspect import signature
from typing import Dict, List, Tuple
from typing import TYPE_CHECKING, Dict, List, Tuple
from prometheus_client import Counter, Gauge
@ -28,6 +28,9 @@ from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
_pending_outgoing_requests = Gauge(
@ -88,10 +91,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
CACHE = True
RETRY_ON_TIMEOUT = True
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
if self.CACHE:
self.response_cache = ResponseCache(
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
) # type: ResponseCache[str]
# We reserve `instance_name` as a parameter to sending requests, so we

View file

@ -61,7 +61,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
device_id, access_token = await self.registration_handler.register_device(
res = await self.registration_handler.register_device_inner(
user_id,
device_id,
initial_display_name,
@ -69,7 +69,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_appservice_ghost=is_appservice_ghost,
)
return 200, {"device_id": device_id, "access_token": access_token}
return 200, res
def register_servlets(hs, http_server):

View file

@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
AccountDataStream,
@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
]
@ -174,7 +174,7 @@ class ReplicationCommandHandler:
# The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
self._connections = [] # type: List[IReplicationConnection]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
@ -197,7 +197,7 @@ class ReplicationCommandHandler:
# For each connection, the incoming stream names that have received a POSITION
# from that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_command_queue",
@ -220,7 +220,7 @@ class ReplicationCommandHandler:
self._server_notices_sender = hs.get_server_notices_sender()
def _add_command_to_stream_queue(
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
@ -267,7 +267,7 @@ class ReplicationCommandHandler:
async def _process_command(
self,
cmd: Union[PositionCommand, RdataCommand],
conn: AbstractConnection,
conn: IReplicationConnection,
stream_name: str,
) -> None:
if isinstance(cmd, PositionCommand):
@ -302,7 +302,7 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
hs.config.redis.redis_host,
hs.config.redis.redis_host.encode(),
hs.config.redis.redis_port,
self._factory,
)
@ -311,7 +311,7 @@ class ReplicationCommandHandler:
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
hs.get_reactor().connectTCP(host.encode(), port, self._factory)
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""
@ -321,10 +321,10 @@ class ReplicationCommandHandler:
"""Get a list of streams that this instances replicates."""
return self._streams_to_replicate
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
def send_positions_to_connection(self, conn: AbstractConnection):
def send_positions_to_connection(self, conn: IReplicationConnection):
"""Send current position of all streams this process is source of to
the connection.
"""
@ -347,7 +347,7 @@ class ReplicationCommandHandler:
)
def on_USER_SYNC(
self, conn: AbstractConnection, cmd: UserSyncCommand
self, conn: IReplicationConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
@ -359,21 +359,23 @@ class ReplicationCommandHandler:
return None
def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
) -> Optional[Awaitable[None]]:
if self._is_master:
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else:
return None
def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
def on_FEDERATION_ACK(
self, conn: IReplicationConnection, cmd: FederationAckCommand
):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
def on_USER_IP(
self, conn: AbstractConnection, cmd: UserIpCommand
self, conn: IReplicationConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
@ -395,7 +397,7 @@ class ReplicationCommandHandler:
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@ -412,7 +414,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata(
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
) -> None:
"""Process an RDATA command
@ -486,7 +488,7 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
@ -496,7 +498,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_position(
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
) -> None:
"""Process a POSITION command
@ -553,7 +555,9 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
def on_REMOTE_SERVER_UP(
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
@ -576,7 +580,7 @@ class ReplicationCommandHandler:
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)
def new_connection(self, connection: AbstractConnection):
def new_connection(self, connection: IReplicationConnection):
"""Called when we have a new connection."""
self._connections.append(connection)
@ -603,7 +607,7 @@ class ReplicationCommandHandler:
UserSyncCommand(self._instance_id, user_id, True, now)
)
def lost_connection(self, connection: AbstractConnection):
def lost_connection(self, connection: IReplicationConnection):
"""Called when a connection is closed/lost."""
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
@ -624,7 +628,7 @@ class ReplicationCommandHandler:
return bool(self._connections)
def send_command(
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
):
"""Send a command to all connected connections.

View file

@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
import abc
import fcntl
import logging
import struct
@ -54,8 +53,10 @@ from inspect import isawaitable
from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
from zope.interface import Interface, implementer
from twisted.internet import task
from twisted.internet.tcp import Connection
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@ -121,6 +122,14 @@ class ConnectionStates:
CLOSED = "closed"
class IReplicationConnection(Interface):
"""An interface for replication connections."""
def send_command(cmd: Command):
"""Send the command down the connection"""
@implementer(IReplicationConnection)
class BaseReplicationStreamProtocol(LineOnlyReceiver):
"""Base replication protocol shared between client and server.
@ -137,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
(if they send a `PING` command)
"""
# The transport is going to be an ITCPTransport, but that doesn't have the
# (un)registerProducer methods, those are only on the implementation.
transport = None # type: Connection
delimiter = b"\n"
# Valid commands we expect to receive
@ -181,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
connected_connections.append(self) # Register connection for metrics
assert self.transport is not None
self.transport.registerProducer(self, True) # For the *Producing callbacks
self._send_pending_commands()
@ -205,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info(
"[%s] Failed to close connection gracefully, aborting", self.id()
)
assert self.transport is not None
self.transport.abortConnection()
else:
if now - self.last_sent_command >= PING_TIME:
@ -294,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def close(self):
logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
assert self.transport is not None
self.transport.loseConnection()
self.on_connection_closed()
@ -391,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc()
@ -495,20 +512,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ReplicateCommand())
class AbstractConnection(abc.ABC):
"""An interface for replication connections."""
@abc.abstractmethod
def send_command(self, cmd: Command):
"""Send the command down the connection"""
pass
# This tells python that `BaseReplicationStreamProtocol` implements the
# interface.
AbstractConnection.register(BaseReplicationStreamProtocol)
# The following simply registers metrics for the replication connections
pending_commands = LaterGauge(

View file

@ -19,6 +19,11 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
import attr
import txredisapi
from zope.interface import implementer
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import IAddress, IConnector
from twisted.python.failure import Failure
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import (
@ -32,7 +37,7 @@ from synapse.replication.tcp.commands import (
parse_command_from_line,
)
from synapse.replication.tcp.protocol import (
AbstractConnection,
IReplicationConnection,
tcp_inbound_commands_counter,
tcp_outbound_commands_counter,
)
@ -62,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
pass
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
@implementer(IReplicationConnection)
class RedisSubscriber(txredisapi.SubscriberProtocol):
"""Connection to redis subscribed to replication stream.
This class fulfils two functions:
@ -71,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
connection, parsing *incoming* messages into replication commands, and passing them
to `ReplicationCommandHandler`
(b) it implements the AbstractConnection API, where it sends *outgoing* commands
(b) it implements the IReplicationConnection API, where it sends *outgoing* commands
onto outbound_redis_connection.
Due to the vagaries of `txredisapi` we don't want to have a custom
@ -253,6 +259,37 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
except Exception:
logger.warning("Failed to send ping to a redis connection")
# ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
# it's rubbish. We add our own here.
def startedConnecting(self, connector: IConnector):
logger.info(
"Connecting to redis server %s", format_address(connector.getDestination())
)
super().startedConnecting(connector)
def clientConnectionFailed(self, connector: IConnector, reason: Failure):
logger.info(
"Connection to redis server %s failed: %s",
format_address(connector.getDestination()),
reason.value,
)
super().clientConnectionFailed(connector, reason)
def clientConnectionLost(self, connector: IConnector, reason: Failure):
logger.info(
"Connection to redis server %s lost: %s",
format_address(connector.getDestination()),
reason.value,
)
super().clientConnectionLost(connector, reason)
def format_address(address: IAddress) -> str:
if isinstance(address, (IPv4Address, IPv6Address)):
return "%s:%i" % (address.host, address.port)
return str(address)
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
@ -328,6 +365,6 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
reactor.connectTCP(host, port, factory, 30)
reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
return factory.handler

View file

@ -15,10 +15,9 @@
import re
import twisted.web.server
import synapse.api.auth
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
from synapse.http.site import SynapseRequest
from synapse.types import UserID
@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
return patterns
async def assert_requester_is_admin(
auth: synapse.api.auth.Auth, request: twisted.web.server.Request
) -> None:
async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
"""Verify that the requester is an admin user
Args:
auth: api.auth.Auth singleton
auth: Auth singleton
request: incoming request
Raises:
@ -53,11 +50,11 @@ async def assert_requester_is_admin(
await assert_user_is_admin(auth, requester.user)
async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
"""Verify that the given user is an admin user
Args:
auth: api.auth.Auth singleton
auth: Auth singleton
user_id: user to check
Raises:

View file

@ -17,10 +17,9 @@
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet):
self.auth = hs.get_auth()
async def on_POST(
self, request: Request, server_name: str, media_id: str
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet):
self.media_repository = hs.get_media_repository()
async def on_DELETE(
self, request: Request, server_name: str, media_id: str
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet):
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, server_name: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)

View file

@ -12,13 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Tuple
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
class PurgeRoomServlet(RestServlet):
@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet):
PATTERNS = admin_patterns("/purge_room$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.pagination_handler = hs.get_pagination_handler()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)

View file

@ -685,7 +685,10 @@ class RoomEventContextServlet(RestServlet):
results["events_after"], time_now
)
results["state"] = await self._event_serializer.serialize_events(
results["state"], time_now
results["state"],
time_now,
# No need to bundle aggregations for state events
bundle_aggregations=False,
)
return 200, results

View file

@ -12,17 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import UserID
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
class SendServerNoticeServlet(RestServlet):
@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet):
}
"""
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
self.snm = hs.get_server_notices_manager()
def register(self, json_resource):
def register(self, json_resource: HttpServer):
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
@ -66,7 +69,9 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__,
)
async def on_POST(self, request, txn_id=None):
async def on_POST(
self, request: SynapseRequest, txn_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
@ -90,7 +95,7 @@ class SendServerNoticeServlet(RestServlet):
return 200, {"event_id": event.event_id}
def on_PUT(self, request, txn_id):
def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
)

View file

@ -269,7 +269,10 @@ class UserRestServletV2(RestServlet):
target_user.to_string(), False, requester, by_admin=True
)
elif not deactivate and user["deactivated"]:
if "password" not in body:
if (
"password" not in body
and self.hs.config.password_localdb_enabled
):
raise SynapseError(
400, "Must provide a password to re-activate an account."
)

View file

@ -14,10 +14,12 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.appservice import ApplicationService
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http import get_request_uri
@ -94,11 +96,21 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
sso_flow = {
"type": LoginRestServlet.SSO_TYPE,
"identity_providers": [
_get_auth_flow_dict_for_idp(
idp,
)
for idp in self._sso_handler.get_identity_providers().values()
],
} # type: JsonDict
if self._msc2858_enabled:
# backwards-compatibility support for clients which don't
# support the stable API yet
sso_flow["org.matrix.msc2858.identity_providers"] = [
_get_auth_flow_dict_for_idp(idp)
_get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
for idp in self._sso_handler.get_identity_providers().values()
]
@ -219,6 +231,7 @@ class LoginRestServlet(RestServlet):
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
ratelimit: bool = True,
auth_provider_id: Optional[str] = None,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@ -234,6 +247,8 @@ class LoginRestServlet(RestServlet):
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
ratelimit: Whether to ratelimit the login request.
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
Returns:
result: Dictionary of account information after successful login.
@ -256,7 +271,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name
user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
)
result = {
@ -283,12 +298,13 @@ class LoginRestServlet(RestServlet):
"""
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
res = await auth_handler.validate_short_term_login_token(token)
return await self._complete_login(
user_id, login_submission, self.auth_handler._sso_login_callback
res.user_id,
login_submission,
self.auth_handler._sso_login_callback,
auth_provider_id=res.auth_provider_id,
)
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
@ -327,22 +343,38 @@ class LoginRestServlet(RestServlet):
return result
def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
def _get_auth_flow_dict_for_idp(
idp: SsoIdentityProvider, use_unstable_brands: bool = False
) -> JsonDict:
"""Return an entry for the login flow dict
Returns an entry suitable for inclusion in "identity_providers" in the
response to GET /_matrix/client/r0/login
Args:
idp: the identity provider to describe
use_unstable_brands: whether we should use brand identifiers suitable
for the unstable API
"""
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
if idp.idp_icon:
e["icon"] = idp.idp_icon
if idp.idp_brand:
e["brand"] = idp.idp_brand
# use the stable brand identifier if the unstable identifier isn't defined.
if use_unstable_brands and idp.unstable_idp_brand:
e["brand"] = idp.unstable_idp_brand
return e
class SsoRedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
re.compile(
"^"
+ CLIENT_API_PREFIX
+ "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
)
]
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
@ -360,7 +392,8 @@ class SsoRedirectServlet(RestServlet):
def register(self, http_server: HttpServer) -> None:
super().register(http_server)
if self._msc2858_enabled:
# expose additional endpoint for MSC2858 support
# expose additional endpoint for MSC2858 support: backwards-compat support
# for clients which don't yet support the stable endpoints.
http_server.register_paths(
"GET",
client_patterns(

View file

@ -674,7 +674,10 @@ class RoomEventContextServlet(RestServlet):
results["events_after"], time_now
)
results["state"] = await self._event_serializer.serialize_events(
results["state"], time_now
results["state"],
time_now,
# No need to bundle aggregations for state events
bundle_aggregations=False,
)
return 200, results

View file

@ -32,6 +32,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.types import GroupID, JsonDict
from ._base import client_patterns
@ -70,7 +71,9 @@ class GroupServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@ -81,7 +84,9 @@ class GroupServlet(RestServlet):
return 200, group_description
@_validate_group_id
async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, category_id: Optional[str], room_id: str
self,
request: SynapseRequest,
group_id: str,
category_id: Optional[str],
room_id: str,
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: Request, group_id: str, category_id: str, room_id: str
self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_GET(
self, request: Request, group_id: str, category_id: str
self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, category_id: str
self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: Request, group_id: str, category_id: str
self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_GET(
self, request: Request, group_id: str, role_id: str
self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, role_id: str
self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: Request, group_id: str, role_id: str
self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, role_id: Optional[str], user_id: str
self,
request: SynapseRequest,
group_id: str,
role_id: Optional[str],
user_id: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: Request, group_id: str, role_id: str, user_id: str
self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, room_id: str
self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: Request, group_id: str, room_id: str
self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, room_id: str, config_key: str
self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.is_mine_id = hs.is_mine_id
@_validate_group_id
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id, user_id
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id, user_id
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.store = hs.get_datastore()
@_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()

View file

@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource):
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
async def _async_render_GET(self, request: Request) -> None:
async def _async_render_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)

View file

@ -35,6 +35,7 @@ from synapse.api.errors import (
from synapse.config._base import ConfigError
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import random_string
@ -145,7 +146,7 @@ class MediaRepository:
upload_name: Optional[str],
content: IO,
content_length: int,
auth_user: str,
auth_user: UserID,
) -> str:
"""Store uploaded content for a local user and return the mxc URL

View file

@ -39,6 +39,7 @@ from synapse.http.server import (
respond_with_json_bytes,
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
@ -174,7 +175,7 @@ class PreviewUrlResource(DirectServeJsonResource):
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request: Request) -> None:
async def _async_render_GET(self, request: SynapseRequest) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)

View file

@ -96,9 +96,14 @@ class Thumbnailer:
def _resize(self, width: int, height: int) -> Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
# looks awful
if self.image.mode in ["1", "P"]:
self.image = self.image.convert("RGB")
# looks awful.
#
# If the image has transparency, use RGBA instead.
if self.image.mode in ["1", "L", "P"]:
mode = "RGB"
if self.image.info.get("transparency", None) is not None:
mode = "RGBA"
self.image = self.image.convert(mode)
return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width: int, height: int, output_type: str) -> BytesIO:

View file

@ -22,6 +22,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import SpamMediaException
if TYPE_CHECKING:
@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource):
async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request: Request) -> None:
async def _async_render_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point

View file

@ -14,24 +14,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from synapse.http.server import DirectServeHtmlResource
if TYPE_CHECKING:
from synapse.server import HomeServer
class SAML2ResponseResource(DirectServeHtmlResource):
"""A Twisted web resource which handles the SAML response"""
isLeaf = 1
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self._saml_handler = hs.get_saml_handler()
self._sso_handler = hs.get_sso_handler()
async def _async_render_GET(self, request):
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
# try to authenticate again.
self._saml_handler._render_error(
self._sso_handler.render_error(
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)

View file

@ -36,7 +36,6 @@ from typing import (
cast,
)
import twisted.internet.base
import twisted.internet.tcp
from twisted.internet import defer
from twisted.mail.smtp import sendmail
@ -130,7 +129,7 @@ from synapse.server_notices.worker_server_notices_sender import (
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import Databases, DataStore, Storage
from synapse.streams.events import EventSources
from synapse.types import DomainSpecificString
from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.ratelimitutils import FederationRateLimiter
@ -291,7 +290,7 @@ class HomeServer(metaclass=abc.ABCMeta):
for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
getattr(self, "get_" + i + "_handler")()
def get_reactor(self) -> twisted.internet.base.ReactorBase:
def get_reactor(self) -> ISynapseReactor:
"""
Fetch the Twisted reactor in use by this HomeServer.
"""
@ -352,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
return (
InsecureInterceptableContextFactory()
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
else RegularPolicyForHTTPS()
)
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
return InsecureInterceptableContextFactory()
return RegularPolicyForHTTPS()
@cache_in_self
def get_simple_http_client(self) -> SimpleHttpClient:

View file

@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
) # type: LruCache[str, List[Tuple[str, int]]]
async def get_auth_chain(
self, event_ids: Collection[str], include_given: bool = False
self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
list of events
"""
event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
room_id, event_ids, include_given=include_given
)
return await self.get_events_as_list(event_ids)
async def get_auth_chain_ids(
self,
room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
Returns:
An awaitable which resolve to a list of event_ids
list of event_ids
"""
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id)
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains",
self._get_auth_chain_ids_using_cover_index_txn,
room_id,
event_ids,
include_given,
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
# for the events in question, so we fall back to the old method.
pass
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
include_given,
)
def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
) -> List[str]:
"""Calculates the auth chain IDs using the chain index."""
# First we look up the chain ID/sequence numbers for the given events.
initial_events = set(event_ids)
# All the events that we've found that are reachable from the events.
seen_events = set() # type: Set[str]
# A map from chain ID to max sequence number of the given events.
event_chains = {} # type: Dict[int, int]
sql = """
SELECT event_id, chain_id, sequence_number
FROM event_auth_chains
WHERE %s
"""
for batch in batch_iter(initial_events, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)
for event_id, chain_id, sequence_number in txn:
seen_events.add(event_id)
event_chains[chain_id] = max(
sequence_number, event_chains.get(chain_id, 0)
)
# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(seen_events)
if events_missing_chain_info:
# This can happen due to e.g. downgrade/upgrade of the server. We
# raise an exception and fall back to the previous algorithm.
logger.info(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info,
)
raise _NoChainCoverIndex(room_id)
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
sql = """
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
"""
# A map from chain ID to max sequence number *reachable* from any event ID.
chains = {} # type: Dict[int, int]
# Add all linked chains reachable from initial set of chains.
for batch in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch
)
txn.execute(sql % (clause,), args)
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
# chains are only reachable if the origin sequence number of
# the link is less than the max sequence number in the
# origin chain.
if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
target_sequence_number,
chains.get(target_chain_id, 0),
)
# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
for chain_id, seq_no in event_chains.items():
chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
# Now for each chain we figure out the maximum sequence number reachable
# from *any* event ID. Events with a sequence less than that are in the
# auth chain.
if include_given:
results = initial_events
else:
results = set()
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
WHERE
c.chain_id = l.chain_id
AND sequence_number <= max_seq
"""
rows = txn.execute_values(sql, chains.items())
results.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.
sql = """
SELECT event_id FROM event_auth_chains
WHERE chain_id = ? AND sequence_number <= ?
"""
for chain_id, max_no in chains.items():
txn.execute(sql, (chain_id, max_no))
results.update(r for r, in txn)
return list(results)
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
"""Calculates the auth chain IDs.
This is used when we don't have a cover index for the room.
"""
if include_given:
results = set(event_ids)
else:

View file

@ -135,6 +135,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._chain_cover_index,
)
self.db_pool.updates.register_background_update_handler(
"purged_chain_cover",
self._purged_chain_cover_index,
)
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@ -932,3 +937,77 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
processed_count=count,
finished_room_map=finished_rooms,
)
async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int:
"""
A background updates that iterates over the chain cover and deletes the
chain cover for events that have been purged.
This may be due to fully purging a room or via setting a retention policy.
"""
current_event_id = progress.get("current_event_id", "")
def purged_chain_cover_txn(txn) -> int:
# The event ID from events will be null if the chain ID / sequence
# number points to a purged event.
sql = """
SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL
FROM event_auth_chains
LEFT JOIN events AS e USING (event_id)
WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ?
"""
txn.execute(sql, (current_event_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
# The event IDs and chain IDs / sequence numbers where the event has
# been purged.
unreferenced_event_ids = []
unreferenced_chain_id_tuples = []
event_id = ""
for event_id, chain_id, sequence_number, has_event in rows:
if not has_event:
unreferenced_event_ids.append((event_id,))
unreferenced_chain_id_tuples.append((chain_id, sequence_number))
# Delete the unreferenced auth chains from event_auth_chain_links and
# event_auth_chains.
txn.executemany(
"""
DELETE FROM event_auth_chains WHERE event_id = ?
""",
unreferenced_event_ids,
)
# We should also delete matching target_*, but there is no index on
# target_chain_id. Hopefully any purged events are due to a room
# being fully purged and they will be removed from the origin_*
# searches.
txn.executemany(
"""
DELETE FROM event_auth_chain_links WHERE
origin_chain_id = ? AND origin_sequence_number = ?
""",
unreferenced_chain_id_tuples,
)
progress = {
"current_event_id": event_id,
}
self.db_pool.updates._background_update_progress_txn(
txn, "purged_chain_cover", progress
)
return len(rows)
result = await self.db_pool.runInteraction(
"_purged_chain_cover_index",
purged_chain_cover_txn,
)
if not result:
await self.db_pool.updates._end_background_update("purged_chain_cover")
return result

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
import threading
from collections import namedtuple
@ -1044,7 +1044,8 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
set[str]: The events we have already seen.
"""
results = set()
# if the event cache contains the event, obviously we've seen it.
results = {x for x in event_ids if self._get_event_cache.contains(x)}
def have_seen_events_txn(txn, chunk):
sql = "SELECT event_id FROM events as e WHERE "
@ -1052,12 +1053,9 @@ class EventsWorkerStore(SQLBaseStore):
txn.database_engine, "e.event_id", chunk
)
txn.execute(sql + clause, args)
for (event_id,) in txn:
results.add(event_id)
results.update(row[0] for row in txn)
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
for chunk in batch_iter((x for x in event_ids if x not in results), 100):
await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)

View file

@ -331,13 +331,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
txn.executemany(
"""
DELETE FROM event_auth_chain_links WHERE
(origin_chain_id = ? AND origin_sequence_number = ?) OR
(target_chain_id = ? AND target_sequence_number = ?)
origin_chain_id = ? AND origin_sequence_number = ?
""",
(
(chain_id, seq_num, chain_id, seq_num)
for (chain_id, seq_num) in referenced_chain_id_tuples
),
referenced_chain_id_tuples,
)
# Now we delete tables which lack an index on room_id but have one on event_id

View file

@ -16,7 +16,7 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
@ -1510,7 +1510,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
async def user_delete_access_tokens(
self,
user_id: str,
except_token_id: Optional[str] = None,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
@ -1533,7 +1533,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values = [v for _, v in items]
values = [v for _, v in items] # type: List[Union[str, int]]
if except_token_id:
where_clause += " AND id != ?"
values.append(except_token_id)

View file

@ -0,0 +1,17 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5910, 'purged_chain_cover', '{}');

View file

@ -350,11 +350,11 @@ class TransactionStore(TransactionWorkerStore):
self.db_pool.simple_upsert_many_txn(
txn,
"destination_rooms",
["destination", "room_id"],
rows,
["stream_ordering"],
[(stream_ordering,)] * len(rows),
table="destination_rooms",
key_names=("destination", "room_id"),
key_values=rows,
value_names=["stream_ordering"],
value_values=[(stream_ordering,)] * len(rows),
)
async def get_destination_last_successful_stream_ordering(

View file

@ -35,6 +35,14 @@ from typing import (
import attr
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
from zope.interface import Interface
from twisted.internet.interfaces import (
IReactorCore,
IReactorPluggableNameResolver,
IReactorTCP,
IReactorTime,
)
from synapse.api.errors import Codes, SynapseError
from synapse.util.stringutils import parse_and_validate_server_name
@ -67,33 +75,40 @@ MutableStateMap = MutableMapping[StateKey, T]
JsonDict = Dict[str, Any]
class Requester(
namedtuple(
"Requester",
[
"user",
"access_token_id",
"is_guest",
"shadow_banned",
"device_id",
"app_service",
"authenticated_entity",
],
)
# Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface.
class ISynapseReactor(
IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
):
"""The interfaces necessary for Synapse to function."""
@attr.s(frozen=True, slots=True)
class Requester:
"""
Represents the user making a request
Attributes:
user (UserID): id of the user making the request
access_token_id (int|None): *ID* of the access token used for this
user: id of the user making the request
access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
shadow_banned (bool): True if the user making this request has been shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
is_guest: True if the user making this request is a guest user
shadow_banned: True if the user making this request has been shadow-banned.
device_id: device_id which was set at authentication time
app_service: the AS requesting on behalf of the user
authenticated_entity: The entity that authenticated when making the request.
This is different to the user_id when an admin user or the server is
"puppeting" the user.
"""
user = attr.ib(type="UserID")
access_token_id = attr.ib(type=Optional[int])
is_guest = attr.ib(type=bool)
shadow_banned = attr.ib(type=bool)
device_id = attr.ib(type=Optional[str])
app_service = attr.ib(type=Optional["ApplicationService"])
authenticated_entity = attr.ib(type=str)
def serialize(self):
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@ -141,23 +156,23 @@ class Requester(
def create_requester(
user_id: Union[str, "UserID"],
access_token_id: Optional[int] = None,
is_guest: Optional[bool] = False,
shadow_banned: Optional[bool] = False,
is_guest: bool = False,
shadow_banned: bool = False,
device_id: Optional[str] = None,
app_service: Optional["ApplicationService"] = None,
authenticated_entity: Optional[str] = None,
):
) -> Requester:
"""
Create a new ``Requester`` object
Args:
user_id (str|UserID): id of the user making the request
access_token_id (int|None): *ID* of the access token used for this
user_id: id of the user making the request
access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
is_guest: True if the user making this request is a guest user
shadow_banned: True if the user making this request is shadow-banned.
device_id: device_id which was set at authentication time
app_service: the AS requesting on behalf of the user
authenticated_entity: The entity that authenticated when making the request.
This is different to the user_id when an admin user or the server is
"puppeting" the user.

View file

@ -76,11 +76,16 @@ class ObservableDeferred:
def callback(r):
object.__setattr__(self, "_result", (True, r))
while self._observers:
observer = self._observers.pop()
try:
# TODO: Handle errors here.
self._observers.pop().callback(r)
except Exception:
pass
observer.callback(r)
except Exception as e:
logger.exception(
"%r threw an exception on .callback(%r), ignoring...",
observer,
r,
exc_info=e,
)
return r
def errback(f):
@ -90,11 +95,16 @@ class ObservableDeferred:
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
observer = self._observers.pop()
try:
# TODO: Handle errors here.
self._observers.pop().errback(f)
except Exception:
pass
observer.errback(f)
except Exception as e:
logger.exception(
"%r threw an exception on .errback(%r), ignoring...",
observer,
f,
exc_info=e,
)
if consumeErrors:
return None

View file

@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util import Clock
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
T = TypeVar("T")
@ -37,11 +35,11 @@ class ResponseCache(Generic[T]):
used rather than trying to compute a new response.
"""
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
self.clock = hs.get_clock()
self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
self._name = name

89
synapse/util/macaroons.py Normal file
View file

@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for manipulating macaroons"""
from typing import Callable, Optional
import pymacaroons
from pymacaroons.exceptions import MacaroonVerificationFailedException
def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
and returns the extracted value.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
MacaroonVerificationFailedException: if there are conflicting values for the
caveat in the macaroon, or if the caveat was not found in the macaroon.
"""
prefix = key + " = "
result = None # type: Optional[str]
for caveat in macaroon.caveats:
if not caveat.caveat_id.startswith(prefix):
continue
val = caveat.caveat_id[len(prefix) :]
if result is None:
# first time we found this caveat: record the value
result = val
elif val != result:
# on subsequent occurrences, raise if the value is different.
raise MacaroonVerificationFailedException(
"Conflicting values for caveat " + key
)
if result is not None:
return result
# If the caveat is not there, we raise a MacaroonVerificationFailedException.
# Note that it is insecure to generate a macaroon without all the caveats you
# might need (because there is nothing stopping people from adding extra caveats),
# so if the caveat isn't there, something odd must be going on.
raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
"""Make a macaroon verifier which accepts 'time' caveats
Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
the given macaroon verifier.
Args:
v: the macaroon verifier
get_time_ms: a callable which will return the timestamp after which the caveat
should be considered expired. Normally the current time.
"""
def verify_expiry_caveat(caveat: str):
time_msec = get_time_ms()
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
return time_msec < expiry
v.satisfy_general(verify_expiry_caveat)

View file

@ -7,6 +7,7 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable
from tests.unittest import FederatingHomeserverTestCase, override_config
@ -49,7 +50,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
else:
data = json_cb()
self.failed_pdus.extend(data["pdus"])
raise IOError("Failed to connect because this is a test!")
raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination)
def get_destination_room(self, room: str, destination: str = "host2") -> dict:
"""

View file

@ -0,0 +1,5 @@
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp
Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA
P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW
-----END PRIVATE KEY-----

View file

@ -0,0 +1,4 @@
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi
QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg==
-----END PUBLIC KEY-----

View file

@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", "", 5000
)
self.assertEqual("a_user", user_id)
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
self.assertEqual("a_user", res.user_id)
self.assertEqual("", res.auth_provider_id)
# when we advance the clock, the token should be rejected
self.reactor.advance(6)
self.get_failure(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
self.auth_handler.validate_short_term_login_token(token),
AuthError,
)
def test_short_term_login_token_gives_auth_provider(self):
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", auth_provider_id="my_idp"
)
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
self.assertEqual("a_user", res.user_id)
self.assertEqual("my_idp", res.auth_provider_id)
def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", "", 5000
)
macaroon = pymacaroons.Macaroon.deserialize(token)
user_id = self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
res = self.get_success(
self.auth_handler.validate_short_term_login_token(macaroon.serialize())
)
self.assertEqual("a_user", user_id)
self.assertEqual("a_user", res.user_id)
# add another "user_id" caveat, which might allow us to override the
# user_id.
macaroon.add_first_party_caveat("user_id = b_user")
self.get_failure(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
),
self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
AuthError,
)
@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.large_number_of_users)
)
self.get_failure(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
self.get_failure(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.small_number_of_users)
)
self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
def _get_macaroon(self):
token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
token = self.macaroon_generator.generate_short_term_login_token(
"user_a", "", 5000
)
return pymacaroons.Macaroon.deserialize(token)

View file

@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)
def test_map_cas_user_to_existing_user(self):
@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=False
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=False
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
def test_map_cas_user_to_invalid_localpart(self):
@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
"@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
)
@override_config(
@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Optional
import os
from urllib.parse import parse_qs, urlparse
from mock import ANY, Mock, patch
@ -23,6 +23,7 @@ import pymacaroons
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@ -50,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
JWKS_URI = ISSUER + ".well-known/jwks.json"
# config for common cases
COMMON_CONFIG = {
DEFAULT_CONFIG = {
"enabled": True,
"client_id": CLIENT_ID,
"client_secret": CLIENT_SECRET,
"issuer": ISSUER,
"scopes": SCOPES,
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# extends the default config with explicit OAuth2 endpoints instead of using discovery
EXPLICIT_ENDPOINT_CONFIG = {
**DEFAULT_CONFIG,
"discover": False,
"authorization_endpoint": AUTHORIZATION_ENDPOINT,
"token_endpoint": TOKEN_ENDPOINT,
@ -107,6 +119,32 @@ async def get_json(url):
return {"keys": []}
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
# this key was generated with:
# openssl ecparam -name prime256v1 -genkey -noout |
# openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
#
# we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
# that's what Apple use, and we want to be sure that we work with Apple's keys.
#
# (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
# keys using ASN.1. Both are then typically formatted using PEM, which says: use the
# base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
# really need to care about any of that.)
return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
def _public_key_file_path() -> str:
"""path to a file containing the public half of a test key"""
# this was generated with:
# openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
#
# See above about where oidc_test_key.p8 came from
return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC:
skip = "requires OIDC"
@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def default_config(self):
config = super().default_config()
config["public_baseurl"] = BASE_URL
oidc_config = {
"enabled": True,
"client_id": CLIENT_ID,
"client_secret": CLIENT_SECRET,
"issuer": ISSUER,
"scopes": SCOPES,
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# Update this config with what's in the default config so that
# override_config works as expected.
oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
return config
def make_homeserver(self, reactor, clock):
@ -170,13 +194,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.reset_mock()
return args
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}})
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
@ -195,13 +220,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
@ -236,6 +261,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
h = self.provider
@ -318,13 +344,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Shouldn't raise with a valid userinfo, even without jwks
force_load_metadata()
@override_config({"oidc_config": {"skip_verification": True}})
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
def test_skip_verification(self):
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
get_awaitable_result(self.provider.load_metadata())
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"])
@ -360,20 +387,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(name, b"oidc_session")
macaroon = pymacaroons.Macaroon.deserialize(cookie)
state = self.handler._token_generator._get_value_from_macaroon(
macaroon, "state"
)
nonce = self.handler._token_generator._get_value_from_macaroon(
macaroon, "nonce"
)
redirect = self.handler._token_generator._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
state = get_value_from_macaroon(macaroon, "state")
nonce = get_value_from_macaroon(macaroon, "nonce")
redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
self.assertEqual(params["state"], [state])
self.assertEqual(params["nonce"], [nonce])
self.assertEqual(redirect, "http://client/redirect")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
request = Mock(args={})
@ -385,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "some description")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback(self):
"""Code callback works and display errors if something went wrong.
@ -434,7 +457,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, None, new_user=True
expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@ -465,7 +488,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, None, new_user=False
expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_not_called()
@ -486,6 +509,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"])
@ -528,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
@override_config(
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
)
def test_exchange_code(self):
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
@ -613,9 +639,105 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{
"oidc_config": {
"enabled": True,
"client_id": CLIENT_ID,
"issuer": ISSUER,
"client_auth_method": "client_secret_post",
"client_secret_jwt_key": {
"key_file": _key_file_path(),
"jwt_header": {"alg": "ES256", "kid": "ABC789"},
"jwt_payload": {"iss": "DEFGHI"},
},
}
}
)
def test_exchange_code_jwt_key(self):
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
token = {"type": "bearer"}
self.http_client.request = simple_async_mock(
return_value=FakeResponse(
code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
)
)
code = "code"
# advance the clock a bit before we start, so we aren't working with zero
# timestamps.
self.reactor.advance(1000)
start_time = self.reactor.seconds()
ret = self.get_success(self.provider._exchange_code(code))
self.assertEqual(ret, token)
# the request should have hit the token endpoint
kwargs = self.http_client.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
# the client secret provided to the should be a jwt which can be checked with
# the public key
args = parse_qs(kwargs["data"].decode("utf-8"))
secret = args["client_secret"][0]
with open(_public_key_file_path()) as f:
key = f.read()
claims = jwt.decode(secret, key)
self.assertEqual(claims.header["kid"], "ABC789")
self.assertEqual(claims["aud"], ISSUER)
self.assertEqual(claims["iss"], "DEFGHI")
self.assertEqual(claims["sub"], CLIENT_ID)
self.assertEqual(claims["iat"], start_time)
self.assertGreater(claims["exp"], start_time)
# check the rest of the POSTed data
self.assertEqual(args["grant_type"], ["authorization_code"])
self.assertEqual(args["code"], [code])
self.assertEqual(args["client_id"], [CLIENT_ID])
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
@override_config(
{
"oidc_config": {
"enabled": True,
"client_id": CLIENT_ID,
"issuer": ISSUER,
"client_auth_method": "none",
}
}
)
def test_exchange_code_no_auth(self):
"""Test that code exchange works with no client secret."""
token = {"type": "bearer"}
self.http_client.request = simple_async_mock(
return_value=FakeResponse(
code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
)
)
code = "code"
ret = self.get_success(self.provider._exchange_code(code))
self.assertEqual(ret, token)
# the request should have hit the token endpoint
kwargs = self.http_client.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
# check the POSTed data
args = parse_qs(kwargs["data"].decode("utf-8"))
self.assertEqual(args["grant_type"], ["authorization_code"])
self.assertEqual(args["code"], [code])
self.assertEqual(args["client_id"], [CLIENT_ID])
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
@override_config(
{
"oidc_config": {
**DEFAULT_CONFIG,
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderExtra"
}
},
}
}
)
@ -651,12 +773,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_called_once_with(
"@foo:test",
"oidc",
request,
client_redirect_url,
{"phone": "1234567"},
new_user=True,
)
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler()
@ -668,7 +792,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", ANY, ANY, None, new_user=True
"@test_user:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@ -679,7 +803,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user_2:test", ANY, ANY, None, new_user=True
"@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@ -697,7 +821,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"Mapping provider does not support de-duplicating Matrix IDs",
)
@override_config({"oidc_config": {"allow_existing_users": True}})
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastore()
@ -716,14 +840,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None, new_user=False
user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None, new_user=False
user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@ -738,7 +862,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None, new_user=False
user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@ -774,9 +898,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test", ANY, ANY, None, new_user=False
"@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
)
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success(
@ -787,9 +912,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{
"oidc_config": {
**DEFAULT_CONFIG,
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderFailures"
}
},
}
}
)
@ -810,7 +936,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", ANY, ANY, None, new_user=True
"@test_user1:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@ -834,6 +960,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_empty_localpart(self):
"""Attempts to map onto an empty localpart should be rejected."""
userinfo = {
@ -846,9 +973,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{
"oidc_config": {
**DEFAULT_CONFIG,
"user_mapping_provider": {
"config": {"localpart_template": "{{ user.username }}"}
}
},
}
}
)
@ -866,7 +994,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
state: str,
nonce: str,
client_redirect_url: str,
ui_auth_session_id: Optional[str] = None,
ui_auth_session_id: str = "",
) -> str:
from synapse.handlers.oidc_handler import OidcSessionData
@ -909,6 +1037,7 @@ async def _make_callback_with_userinfo(
idp_id="oidc",
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id="",
),
)
request = _build_callback_request("code", state, session)

View file

@ -517,6 +517,37 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
def test_spam_checker_receives_sso_type(self):
"""Test rejecting registration based on SSO type"""
class BanBadIdPUser:
def check_registration_for_spam(
self, email_threepid, username, request_info, auth_provider_id=None
):
# Reject any user coming from CAS and whose username contains profanity
if auth_provider_id == "cas" and "flimflob" in username:
return RegistrationBehaviour.DENY
return RegistrationBehaviour.ALLOW
# Configure a spam checker that denies a certain user on a specific IdP
spam_checker = self.hs.get_spam_checker()
spam_checker.spam_checkers = [BanBadIdPUser()]
f = self.get_failure(
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
SynapseError,
)
exception = f.value
# We return 429 from the spam checker for denied registrations
self.assertIsInstance(exception, SynapseError)
self.assertEqual(exception.code, 429)
# Check the same username can register using SAML
self.get_success(
self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
)
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):

View file

@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "", None, new_user=False
"@test_user:test", "saml", request, "", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
self.handler._handle_authn_response(request, saml_response, "")
)
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "", None, new_user=False
"@test_user:test", "saml", request, "", None, new_user=False
)
def test_map_saml_response_to_invalid_localpart(self):
@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", request, "", None, new_user=True
"@test_user1:test", "saml", request, "", None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)

View file

@ -16,12 +16,23 @@ from io import BytesIO
from mock import Mock
from netaddr import IPSet
from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.client import Agent, ResponseDone
from twisted.web.iweb import UNKNOWN_LENGTH
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from synapse.api.errors import SynapseError
from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BodyExceededMaxSize,
read_body_with_max_size,
)
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):
# The data is never consumed.
self.assertEqual(result.getvalue(), b"")
class BlacklistingAgentTest(TestCase):
def setUp(self):
self.reactor, self.clock = get_clock()
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
# Configure the reactor's DNS resolver.
for (domain, ip) in (
(self.safe_domain, self.safe_ip),
(self.unsafe_domain, self.unsafe_ip),
(self.allowed_domain, self.allowed_ip),
):
self.reactor.lookups[domain.decode()] = ip.decode()
self.reactor.lookups[ip.decode()] = ip.decode()
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
self.ip_blacklist = IPSet(["5.0.0.0/8"])
def test_reactor(self):
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
agent = Agent(
BlacklistingReactorWrapper(
self.reactor,
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
),
)
# The unsafe domains and IPs should be rejected.
for domain in (self.unsafe_domain, self.unsafe_ip):
self.failureResultOf(
agent.request(b"GET", b"http://" + domain), DNSLookupError
)
# The safe domains IPs should be accepted.
for domain in (
self.safe_domain,
self.allowed_domain,
self.safe_ip,
self.allowed_ip,
):
d = agent.request(b"GET", b"http://" + domain)
# Grab the latest TCP connection.
(
host,
port,
client_factory,
_timeout,
_bindAddress,
) = self.reactor.tcpClients[-1]
# Make the connection and pump data through it.
client = client_factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
)
response = self.successResultOf(d)
self.assertEqual(response.code, 200)
def test_agent(self):
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
agent = BlacklistingAgentWrapper(
Agent(self.reactor),
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
)
# The unsafe IPs should be rejected.
self.failureResultOf(
agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
)
# The safe and unsafe domains and safe IPs should be accepted.
for domain in (
self.safe_domain,
self.unsafe_domain,
self.allowed_domain,
self.safe_ip,
self.allowed_ip,
):
d = agent.request(b"GET", b"http://" + domain)
# Grab the latest TCP connection.
(
host,
port,
client_factory,
_timeout,
_bindAddress,
) = self.reactor.tcpClients[-1]
# Make the connection and pump data through it.
client = client_factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
)
response = self.successResultOf(d)
self.assertEqual(response.code, 200)

View file

@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
from twisted.web.server import Request, Site
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@ -32,7 +31,10 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.replication.tcp.resource import (
ReplicationStreamProtocolFactory,
ServerReplicationStreamProtocol,
)
from synapse.server import HomeServer
from synapse.util import Clock
@ -59,7 +61,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
self.server = server_factory.buildProtocol(None)
self.server = server_factory.buildProtocol(
None
) # type: ServerReplicationStreamProtocol
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
@ -152,12 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
channel.site = self.site
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@ -179,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
return request_factory.request
return channel.request
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@ -188,8 +188,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
fetching updates for given stream.
"""
path = request.path # type: bytes # type: ignore
self.assertRegex(
request.path,
path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
% (stream_name.encode("ascii"),),
)
@ -232,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost",
b"localhost",
6379,
self.connect_any_redis_attempts,
)
@ -387,12 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
channel.site = self._hs_to_site[hs]
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@ -418,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(host, b"localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
@ -450,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
self.received_rdata_rows.append((stream_name, token, r))
@attr.s()
class OneShotRequestFactory:
"""A simple request factory that generates a single `SynapseRequest` and
stores it for future use. Can only be used once.
"""
request = attr.ib(default=None)
def __call__(self, *args, **kwargs):
assert self.request is None
self.request = SynapseRequest(*args, **kwargs)
return self.request
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
@ -475,9 +457,13 @@ class _PushHTTPChannel(HTTPChannel):
makes it very hard to test.
"""
def __init__(self, reactor: IReactorTime):
def __init__(
self, reactor: IReactorTime, request_factory: Type[Request], site: Site
):
super().__init__()
self.reactor = reactor
self.requestFactory = request_factory
self.site = site
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
@ -503,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False
def requestDone(self, request):
# Store the request for inspection.
self.request = request
super().requestDone(request)
class _PullToPushProducer:
"""A push producer that wraps a pull producer."""
@ -590,6 +581,8 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""
transport = None # type: Optional[FakeTransport]
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()
@ -634,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):
def send(self, msg):
"""Send a message back to the client."""
assert self.transport is not None
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)

View file

@ -17,7 +17,7 @@ import mock
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream
from tests.unittest import HomeserverTestCase
@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
"""
rch = self.hs.get_tcp_replication()
# wire up the ReplicationCommandHandler to a mock connection
mock_connection = mock.Mock(spec=AbstractConnection)
# wire up the ReplicationCommandHandler to a mock connection, which needs
# to implement IReplicationConnection. (Note that Mock doesn't understand
# interfaces, but casing an interface to a list gives the attributes.)
mock_connection = mock.Mock(spec=list(IReplicationConnection))
rch.new_connection(mock_connection)
# tell it it received an RDATA row

View file

@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
expected_flows = [
{"type": "m.login.cas"},
{"type": "m.login.sso"},
{"type": "m.login.token"},
{"type": "m.login.password"},
] + ADDITIONAL_LOGIN_FLOWS
expected_flow_types = [
"m.login.cas",
"m.login.sso",
"m.login.token",
"m.login.password",
] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]
self.assertCountEqual(channel.json_body["flows"], expected_flows)
self.assertCountEqual(
[f["type"] for f in channel.json_body["flows"]], expected_flow_types
)
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_get_msc2858_login_flows(self):
@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
def test_client_idp_redirect_msc2858_disabled(self):
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_unknown(self):
"""If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request(True, "xxx")
channel = self._make_sso_redirect_request(False, "xxx")
self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_oidc(self):
"""If the client pick a known IdP, redirect to it"""
channel = self._make_sso_redirect_request(False, "oidc")
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_msc2858_redirect_to_oidc(self):
"""Test the unstable API"""
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
def test_client_idp_redirect_msc2858_disabled(self):
"""If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
def _make_sso_redirect_request(
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
):

View file

@ -105,7 +105,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.assertEqual(test_body, body)
@attr.s
@attr.s(slots=True, frozen=True)
class _TestImage:
"""An image for testing thumbnailing with the expected results
@ -117,13 +117,15 @@ class _TestImage:
test should just check for success.
expected_scaled: The expected bytes from scaled thumbnailing, or None if
test should just check for a valid image returned.
expected_found: True if the file should exist on the server, or False if
a 404 is expected.
"""
data = attr.ib(type=bytes)
content_type = attr.ib(type=bytes)
extension = attr.ib(type=bytes)
expected_cropped = attr.ib(type=Optional[bytes])
expected_scaled = attr.ib(type=Optional[bytes])
expected_cropped = attr.ib(type=Optional[bytes], default=None)
expected_scaled = attr.ib(type=Optional[bytes], default=None)
expected_found = attr.ib(default=True, type=bool)
@ -153,6 +155,21 @@ class _TestImage:
),
),
),
# small png with transparency.
(
_TestImage(
unhexlify(
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
b"4154789c636800000082008177cd72b60000000049454e44ae426"
b"082"
),
b"image/png",
b".png",
# Note that we don't check the output since it varies across
# different versions of Pillow.
),
),
# small lossless webp
(
_TestImage(
@ -162,8 +179,6 @@ class _TestImage:
),
b"image/webp",
b".webp",
None,
None,
),
),
# an empty file
@ -172,9 +187,7 @@ class _TestImage:
b"",
b"image/gif",
b".gif",
None,
None,
False,
expected_found=False,
),
),
],

View file

@ -16,6 +16,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IReactorTCP,
IResolverSimple,
ITransport,
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
@ -188,7 +189,7 @@ class FakeSite:
def make_request(
reactor,
site: Site,
site: Union[Site, FakeSite],
method,
path,
content=b"",
@ -467,6 +468,7 @@ def get_clock():
return clock, hs_clock
@implementer(ITransport)
@attr.s(cmp=False)
class FakeTransport:
"""

View file

@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
@parameterized.expand([(True,), (False,)])
def test_auth_difference(self, use_chain_cover_index: bool):
def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"j": 1,
}
# Mark the room as not having a cover index
# Mark the room as maybe having a cover index.
def store_room(txn):
self.store.db_pool.simple_insert_txn(
@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
)
return room_id
@parameterized.expand([(True,), (False,)])
def test_auth_chain_ids(self, use_chain_cover_index: bool):
room_id = self._setup_auth_chain(use_chain_cover_index)
# a and b have the same auth chain.
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["a", "b"])
)
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
# d and e have the same auth chain.
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
self.assertEqual(auth_chain_ids, ["k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
self.assertEqual(auth_chain_ids, ["j"])
# j and k have no parents.
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
self.assertEqual(auth_chain_ids, [])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
self.assertEqual(auth_chain_ids, [])
# More complex input sequences.
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
)
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["h", "i"])
)
self.assertCountEqual(auth_chain_ids, ["k", "j"])
# e gets returned even though include_given is false, but it is in the
# auth chain of b.
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["b", "e"])
)
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
# Test include_given.
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
)
self.assertCountEqual(auth_chain_ids, ["i", "j"])
@parameterized.expand([(True,), (False,)])
def test_auth_difference(self, use_chain_cover_index: bool):
room_id = self._setup_auth_chain(use_chain_cover_index)
# Now actually test that various combinations give the right result:
difference = self.get_success(

View file

@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import NotFoundError
from synapse.api.errors import NotFoundError, SynapseError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@ -33,9 +31,12 @@ class PurgeTests(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
def test_purge(self):
self.store = hs.get_datastore()
self.storage = self.hs.get_storage()
def test_purge_history(self):
"""
Purging a room will delete everything before the topological point.
Purging a room history will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
@ -43,30 +44,27 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
store = self.hs.get_datastore()
storage = self.hs.get_storage()
# Get the topological token
token = self.get_success(
store.get_topological_token_for_event(last["event_id"])
self.store.get_topological_token_for_event(last["event_id"])
)
token_str = self.get_success(token.to_string(self.hs.get_datastore()))
# Purge everything before this topological token
self.get_success(
storage.purge_events.purge_history(self.room_id, token_str, True)
self.storage.purge_events.purge_history(self.room_id, token_str, True)
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
self.get_failure(store.get_event(first["event_id"]), NotFoundError)
self.get_failure(store.get_event(second["event_id"]), NotFoundError)
self.get_failure(store.get_event(third["event_id"]), NotFoundError)
self.get_success(store.get_event(last["event_id"]))
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
self.get_failure(self.store.get_event(second["event_id"]), NotFoundError)
self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
self.get_success(self.store.get_event(last["event_id"]))
def test_purge_wont_delete_extrems(self):
def test_purge_history_wont_delete_extrems(self):
"""
Purging a room will delete everything before the topological point.
Purging a room history will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
@ -74,22 +72,43 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
token = self.get_success(
storage.get_topological_token_for_event(last["event_id"])
self.store.get_topological_token_for_event(last["event_id"])
)
event = "t{}-{}".format(token.topological + 1, token.stream + 1)
# Purge everything before this topological token
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
self.pump()
f = self.failureResultOf(purge)
f = self.get_failure(
self.storage.purge_events.purge_history(self.room_id, event, True),
SynapseError,
)
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
self.get_success(storage.get_event(first["event_id"]))
self.get_success(storage.get_event(second["event_id"]))
self.get_success(storage.get_event(third["event_id"]))
self.get_success(storage.get_event(last["event_id"]))
self.get_success(self.store.get_event(first["event_id"]))
self.get_success(self.store.get_event(second["event_id"]))
self.get_success(self.store.get_event(third["event_id"]))
self.get_success(self.store.get_event(last["event_id"]))
def test_purge_room(self):
"""
Purging a room will delete everything about it.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
# Get the current room state.
state_handler = self.hs.get_state_handler()
create_event = self.get_success(
state_handler.get_current_state(self.room_id, "m.room.create", "")
)
self.assertIsNotNone(create_event)
# Purge everything before this topological token
self.get_success(self.storage.purge_events.purge_room(self.room_id))
# The events aren't found.
self.store._invalidate_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)

View file

@ -28,7 +28,7 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit(
self.tx_log.emit( # type: ignore
twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)

Some files were not shown because too many files have changed in this diff Show more