Merge remote-tracking branch 'upstream/release-v1.26.0'

This commit is contained in:
Tulir Asokan 2021-01-20 18:06:27 +02:00
commit d3547df958
183 changed files with 8528 additions and 2668 deletions

View file

@ -15,6 +15,7 @@
# limitations under the License.
import logging
from synapse.storage.engines import create_engine
logger = logging.getLogger("create_postgres_db")

2
.gitignore vendored
View file

@ -12,10 +12,12 @@
_trial_temp/
_trial_temp*/
/out
.DS_Store
# stuff that is likely to exist when you run a server locally
/*.db
/*.log
/*.log.*
/*.log.config
/*.pid
/.python-version

View file

@ -1,3 +1,90 @@
Synapse 1.26.0rc1 (2021-01-20)
==============================
This release brings a new schema version for Synapse and rolling back to a previous
verious is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details
on these changes and for general upgrade guidance.
Features
--------
- Add support for multiple SSO Identity Providers. ([\#9015](https://github.com/matrix-org/synapse/issues/9015), [\#9017](https://github.com/matrix-org/synapse/issues/9017), [\#9036](https://github.com/matrix-org/synapse/issues/9036), [\#9067](https://github.com/matrix-org/synapse/issues/9067), [\#9081](https://github.com/matrix-org/synapse/issues/9081), [\#9082](https://github.com/matrix-org/synapse/issues/9082), [\#9105](https://github.com/matrix-org/synapse/issues/9105), [\#9107](https://github.com/matrix-org/synapse/issues/9107), [\#9109](https://github.com/matrix-org/synapse/issues/9109), [\#9110](https://github.com/matrix-org/synapse/issues/9110), [\#9127](https://github.com/matrix-org/synapse/issues/9127), [\#9153](https://github.com/matrix-org/synapse/issues/9153), [\#9154](https://github.com/matrix-org/synapse/issues/9154), [\#9177](https://github.com/matrix-org/synapse/issues/9177))
- During user-interactive authentication via single-sign-on, give a better error if the user uses the wrong account on the SSO IdP. ([\#9091](https://github.com/matrix-org/synapse/issues/9091))
- Give the `public_baseurl` a default value, if it is not explicitly set in the configuration file. ([\#9159](https://github.com/matrix-org/synapse/issues/9159))
- Improve performance when calculating ignored users in large rooms. ([\#9024](https://github.com/matrix-org/synapse/issues/9024))
- Implement [MSC2176](https://github.com/matrix-org/matrix-doc/pull/2176) in an experimental room version. ([\#8984](https://github.com/matrix-org/synapse/issues/8984))
- Add an admin API for protecting local media from quarantine. ([\#9086](https://github.com/matrix-org/synapse/issues/9086))
- Remove a user's avatar URL and display name when deactivated with the Admin API. ([\#8932](https://github.com/matrix-org/synapse/issues/8932))
- Update `/_synapse/admin/v1/users/<user_id>/joined_rooms` to work for both local and remote users. ([\#8948](https://github.com/matrix-org/synapse/issues/8948))
- Add experimental support for handling to-device messages on worker processes. ([\#9042](https://github.com/matrix-org/synapse/issues/9042), [\#9043](https://github.com/matrix-org/synapse/issues/9043), [\#9044](https://github.com/matrix-org/synapse/issues/9044), [\#9130](https://github.com/matrix-org/synapse/issues/9130))
- Add experimental support for handling `/keys/claim` and `/room_keys` APIs on worker processes. ([\#9068](https://github.com/matrix-org/synapse/issues/9068))
- Add experimental support for handling `/devices` API on worker processes. ([\#9092](https://github.com/matrix-org/synapse/issues/9092))
- Add experimental support for moving off receipts and account data persistence off master. ([\#9104](https://github.com/matrix-org/synapse/issues/9104), [\#9166](https://github.com/matrix-org/synapse/issues/9166))
Bugfixes
--------
- Fix a long-standing issue where an internal server error would occur when requesting a profile over federation that did not include a display name / avatar URL. ([\#9023](https://github.com/matrix-org/synapse/issues/9023))
- Fix a long-standing bug where some caches could grow larger than configured. ([\#9028](https://github.com/matrix-org/synapse/issues/9028))
- Fix error handling during insertion of client IPs into the database. ([\#9051](https://github.com/matrix-org/synapse/issues/9051))
- Fix bug where we didn't correctly record CPU time spent in `on_new_event` block. ([\#9053](https://github.com/matrix-org/synapse/issues/9053))
- Fix a minor bug which could cause confusing error messages from invalid configurations. ([\#9054](https://github.com/matrix-org/synapse/issues/9054))
- Fix incorrect exit code when there is an error at startup. ([\#9059](https://github.com/matrix-org/synapse/issues/9059))
- Fix `JSONDecodeError` spamming the logs when sending transactions to remote servers. ([\#9070](https://github.com/matrix-org/synapse/issues/9070))
- Fix "Failed to send request" errors when a client provides an invalid room alias. ([\#9071](https://github.com/matrix-org/synapse/issues/9071))
- Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.8.0 and v1.21.0. ([\#9114](https://github.com/matrix-org/synapse/issues/9114), [\#9116](https://github.com/matrix-org/synapse/issues/9116))
- Fix corruption of `pushers` data when a postgres bouncer is used. ([\#9117](https://github.com/matrix-org/synapse/issues/9117))
- Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. ([\#9128](https://github.com/matrix-org/synapse/issues/9128))
- Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when .well-known files that are too large. ([\#9108](https://github.com/matrix-org/synapse/issues/9108))
- Fix "UnboundLocalError: local variable 'length' referenced before assignment" errors when the response body exceeds the expected size. This bug was introduced in v1.25.0. ([\#9145](https://github.com/matrix-org/synapse/issues/9145))
- Fix a long-standing bug "ValueError: invalid literal for int() with base 10" when `/publicRooms` is requested with an invalid `server` parameter. ([\#9161](https://github.com/matrix-org/synapse/issues/9161))
Improved Documentation
----------------------
- Add some extra docs for getting Synapse running on macOS. ([\#8997](https://github.com/matrix-org/synapse/issues/8997))
- Correct a typo in the `systemd-with-workers` documentation. ([\#9035](https://github.com/matrix-org/synapse/issues/9035))
- Correct a typo in `INSTALL.md`. ([\#9040](https://github.com/matrix-org/synapse/issues/9040))
- Add missing `user_mapping_provider` configuration to the Keycloak OIDC example. Contributed by @chris-ruecker. ([\#9057](https://github.com/matrix-org/synapse/issues/9057))
- Quote `pip install` packages when extras are used to avoid shells interpreting bracket characters. ([\#9151](https://github.com/matrix-org/synapse/issues/9151))
Deprecations and Removals
-------------------------
- Remove broken and unmaintained `demo/webserver.py` script. ([\#9039](https://github.com/matrix-org/synapse/issues/9039))
Internal Changes
----------------
- Improve efficiency of large state resolutions. ([\#8868](https://github.com/matrix-org/synapse/issues/8868), [\#9029](https://github.com/matrix-org/synapse/issues/9029), [\#9115](https://github.com/matrix-org/synapse/issues/9115), [\#9118](https://github.com/matrix-org/synapse/issues/9118), [\#9124](https://github.com/matrix-org/synapse/issues/9124))
- Various clean-ups to the structured logging and logging context code. ([\#8939](https://github.com/matrix-org/synapse/issues/8939))
- Ensure rejected events get added to some metadata tables. ([\#9016](https://github.com/matrix-org/synapse/issues/9016))
- Ignore date-rotated homeserver logs saved to disk. ([\#9018](https://github.com/matrix-org/synapse/issues/9018))
- Remove an unused column from `access_tokens` table. ([\#9025](https://github.com/matrix-org/synapse/issues/9025))
- Add a `-noextras` factor to `tox.ini`, to support running the tests with no optional dependencies. ([\#9030](https://github.com/matrix-org/synapse/issues/9030))
- Fix running unit tests when optional dependencies are not installed. ([\#9031](https://github.com/matrix-org/synapse/issues/9031))
- Allow bumping schema version when using split out state database. ([\#9033](https://github.com/matrix-org/synapse/issues/9033))
- Configure the linters to run on a consistent set of files. ([\#9038](https://github.com/matrix-org/synapse/issues/9038))
- Various cleanups to device inbox store. ([\#9041](https://github.com/matrix-org/synapse/issues/9041))
- Drop unused database tables. ([\#9055](https://github.com/matrix-org/synapse/issues/9055))
- Remove unused `SynapseService` class. ([\#9058](https://github.com/matrix-org/synapse/issues/9058))
- Remove unnecessary declarations in the tests for the admin API. ([\#9063](https://github.com/matrix-org/synapse/issues/9063))
- Remove `SynapseRequest.get_user_agent`. ([\#9069](https://github.com/matrix-org/synapse/issues/9069))
- Remove redundant `Homeserver.get_ip_from_request` method. ([\#9080](https://github.com/matrix-org/synapse/issues/9080))
- Add type hints to media repository. ([\#9093](https://github.com/matrix-org/synapse/issues/9093))
- Fix the wrong arguments being passed to `BlacklistingAgentWrapper` from `MatrixFederationAgent`. Contributed by Timothy Leung. ([\#9098](https://github.com/matrix-org/synapse/issues/9098))
- Reduce the scope of caught exceptions in `BlacklistingAgentWrapper`. ([\#9106](https://github.com/matrix-org/synapse/issues/9106))
- Improve `UsernamePickerTestCase`. ([\#9112](https://github.com/matrix-org/synapse/issues/9112))
- Remove dependency on `distutils`. ([\#9125](https://github.com/matrix-org/synapse/issues/9125))
- Enforce that replication HTTP clients are called with keyword arguments only. ([\#9144](https://github.com/matrix-org/synapse/issues/9144))
- Fix the Python 3.5 / old dependencies build in CI. ([\#9146](https://github.com/matrix-org/synapse/issues/9146))
- Replace the old `perspectives` option in the Synapse docker config file template with `trusted_key_servers`. ([\#9157](https://github.com/matrix-org/synapse/issues/9157))
Synapse 1.25.0 (2021-01-13)
===========================

View file

@ -190,7 +190,8 @@ via brew and inform `pip` about it so that `psycopg2` builds:
```sh
brew install openssl@1.1
export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/
export LDFLAGS="-L/usr/local/opt/openssl/lib"
export CPPFLAGS="-I/usr/local/opt/openssl/include"
```
##### OpenSUSE
@ -257,7 +258,7 @@ for a number of platforms.
#### Docker images and Ansible playbooks
There is an offical synapse image available at
There is an official synapse image available at
<https://hub.docker.com/r/matrixdotorg/synapse> which can be used with
the docker-compose file available at [contrib/docker](contrib/docker). Further
information on this including configuration options is available in the README

View file

@ -243,7 +243,7 @@ Then update the ``users`` table in the database::
Synapse Development
===================
Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org)
Join our developer community on Matrix: `#synapse-dev:matrix.org <https://matrix.to/#/#synapse-dev:matrix.org>`_
Before setting up a development environment for synapse, make sure you have the
system dependencies (such as the python header files) installed - see
@ -280,6 +280,27 @@ differ)::
PASSED (skips=15, successes=1322)
We recommend using the demo which starts 3 federated instances running on ports `8080` - `8082`
./demo/start.sh
(to stop, you can use `./demo/stop.sh`)
If you just want to start a single instance of the app and run it directly:
# Create the homeserver.yaml config once
python -m synapse.app.homeserver \
--server-name my.domain.name \
--config-path homeserver.yaml \
--generate-config \
--report-stats=[yes|no]
# Start the app
python -m synapse.app.homeserver --config-path homeserver.yaml
Running the Integration Tests
=============================

View file

@ -85,6 +85,56 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.26.0
====================
Rolling back to v1.25.0 after a failed upgrade
----------------------------------------------
v1.26.0 includes a lot of large changes. If something problematic occurs, you
may want to roll-back to a previous version of Synapse. Because v1.26.0 also
includes a new database schema version, reverting that version is also required
alongside the generic rollback instructions mentioned above. In short, to roll
back to v1.25.0 you need to:
1. Stop the server
2. Decrease the schema version in the database:
.. code:: sql
UPDATE schema_version SET version = 58;
3. Delete the ignored users & chain cover data:
.. code:: sql
DROP TABLE IF EXISTS ignored_users;
UPDATE rooms SET has_auth_chain_index = false;
For PostgreSQL run:
.. code:: sql
TRUNCATE event_auth_chain_links;
TRUNCATE event_auth_chains;
For SQLite run:
.. code:: sql
DELETE FROM event_auth_chain_links;
DELETE FROM event_auth_chains;
4. Mark the deltas as not run (so they will re-run on upgrade).
.. code:: sql
DELETE FROM applied_schema_deltas WHERE version = 59 AND file = "59/01ignored_user.py";
DELETE FROM applied_schema_deltas WHERE version = 59 AND file = "59/06chain_cover_index.sql";
5. Downgrade Synapse by following the instructions for your installation method
in the "Rolling back to older versions" section above.
Upgrading to v1.25.0
====================

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium
* Remove dependency on `python3-distutils`.
-- Richard van der Hoff <richard@matrix.org> Fri, 15 Jan 2021 12:44:19 +0000
matrix-synapse-py3 (1.25.0) stable; urgency=medium
[ Dan Callahan ]

1
debian/control vendored
View file

@ -31,7 +31,6 @@ Pre-Depends: dpkg (>= 1.16.1)
Depends:
adduser,
debconf,
python3-distutils|libpython3-stdlib (<< 3.6),
${misc:Depends},
${shlibs:Depends},
${synapse:pydepends},

View file

@ -1,59 +0,0 @@
import argparse
import BaseHTTPServer
import os
import SimpleHTTPServer
import cgi, logging
from daemonize import Daemonize
class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler):
UPLOAD_PATH = "upload"
"""
Accept all post request as file upload
"""
def do_POST(self):
path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path))
length = self.headers["content-length"]
data = self.rfile.read(int(length))
with open(path, "wb") as fh:
fh.write(data)
self.send_response(200)
self.send_header("Content-Type", "application/json")
self.end_headers()
# Return the absolute path of the uploaded file
self.wfile.write('{"url":"/%s"}' % path)
def setup():
parser = argparse.ArgumentParser()
parser.add_argument("directory")
parser.add_argument("-p", "--port", dest="port", type=int, default=8080)
parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid")
args = parser.parse_args()
# Get absolute path to directory to serve, as daemonize changes to '/'
os.chdir(args.directory)
dr = os.getcwd()
httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST)
def run():
os.chdir(dr)
httpd.serve_forever()
daemon = Daemonize(
app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False
)
daemon.start()
if __name__ == "__main__":
setup()

View file

@ -198,12 +198,10 @@ old_signing_keys: {}
key_refresh_interval: "1d" # 1 Day.
# The trusted servers to download signing keys from.
perspectives:
servers:
"matrix.org":
trusted_key_servers:
- server_name: matrix.org
verify_keys:
"ed25519:auto":
key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
"ed25519:auto": "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
password_config:
enabled: true

View file

@ -4,6 +4,7 @@
* [Quarantining media by ID](#quarantining-media-by-id)
* [Quarantining media in a room](#quarantining-media-in-a-room)
* [Quarantining all media of a user](#quarantining-all-media-of-a-user)
* [Protecting media from being quarantined](#protecting-media-from-being-quarantined)
- [Delete local media](#delete-local-media)
* [Delete a specific local media](#delete-a-specific-local-media)
* [Delete local media by date or size](#delete-local-media-by-date-or-size)
@ -123,6 +124,29 @@ The following fields are returned in the JSON response body:
* `num_quarantined`: integer - The number of media items successfully quarantined
## Protecting media from being quarantined
This API protects a single piece of local media from being quarantined using the
above APIs. This is useful for sticker packs and other shared media which you do
not want to get quarantined, especially when
[quarantining media in a room](#quarantining-media-in-a-room).
Request:
```
POST /_synapse/admin/v1/media/protect/<media_id>
{}
```
Where `media_id` is in the form of `abcdefg12345...`.
Response:
```json
{}
```
# Delete local media
This API deletes the *local* media from the disk of your own server.
This includes any local thumbnails and copies of media downloaded from

View file

@ -98,6 +98,8 @@ Body parameters:
- ``deactivated``, optional. If unspecified, deactivation state will be left
unchanged on existing accounts and set to ``false`` for new accounts.
A user cannot be erased by deactivating with this API. For details on deactivating users see
`Deactivate Account <#deactivate-account>`_.
If the user already exists then optional parameters default to the current value.
@ -248,6 +250,25 @@ server admin: see `README.rst <README.rst>`_.
The erase parameter is optional and defaults to ``false``.
An empty body may be passed for backwards compatibility.
The following actions are performed when deactivating an user:
- Try to unpind 3PIDs from the identity server
- Remove all 3PIDs from the homeserver
- Delete all devices and E2EE keys
- Delete all access tokens
- Delete the password hash
- Removal from all rooms the user is a member of
- Remove the user from the user directory
- Reject all pending invites
- Remove all account validity information related to the user
The following additional actions are performed during deactivation if``erase``
is set to ``true``:
- Remove the user's display name
- Remove the user's avatar URL
- Mark the user as erased
Reset password
==============
@ -337,6 +358,10 @@ A response body like the following is returned:
"total": 2
}
The server returns the list of rooms of which the user and the server
are member. If the user is local, all the rooms of which the user is
member are returned.
**Parameters**
The following parameters should be set in the URL:

32
docs/auth_chain_diff.dot Normal file
View file

@ -0,0 +1,32 @@
digraph auth {
nodesep=0.5;
rankdir="RL";
C [label="Create (1,1)"];
BJ [label="Bob's Join (2,1)", color=red];
BJ2 [label="Bob's Join (2,2)", color=red];
BJ2 -> BJ [color=red, dir=none];
subgraph cluster_foo {
A1 [label="Alice's invite (4,1)", color=blue];
A2 [label="Alice's Join (4,2)", color=blue];
A3 [label="Alice's Join (4,3)", color=blue];
A3 -> A2 -> A1 [color=blue, dir=none];
color=none;
}
PL1 [label="Power Level (3,1)", color=darkgreen];
PL2 [label="Power Level (3,2)", color=darkgreen];
PL2 -> PL1 [color=darkgreen, dir=none];
{rank = same; C; BJ; PL1; A1;}
A1 -> C [color=grey];
A1 -> BJ [color=grey];
PL1 -> C [color=grey];
BJ2 -> PL1 [penwidth=2];
A3 -> PL2 [penwidth=2];
A1 -> PL1 -> BJ -> C [penwidth=2];
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

View file

@ -0,0 +1,108 @@
# Auth Chain Difference Algorithm
The auth chain difference algorithm is used by V2 state resolution, where a
naive implementation can be a significant source of CPU and DB usage.
### Definitions
A *state set* is a set of state events; e.g. the input of a state resolution
algorithm is a collection of state sets.
The *auth chain* of a set of events are all the events' auth events and *their*
auth events, recursively (i.e. the events reachable by walking the graph induced
by an event's auth events links).
The *auth chain difference* of a collection of state sets is the union minus the
intersection of the sets of auth chains corresponding to the state sets, i.e an
event is in the auth chain difference if it is reachable by walking the auth
event graph from at least one of the state sets but not from *all* of the state
sets.
## Breadth First Walk Algorithm
A way of calculating the auth chain difference without calculating the full auth
chains for each state set is to do a parallel breadth first walk (ordered by
depth) of each state set's auth chain. By tracking which events are reachable
from each state set we can finish early if every pending event is reachable from
every state set.
This can work well for state sets that have a small auth chain difference, but
can be very inefficient for larger differences. However, this algorithm is still
used if we don't have a chain cover index for the room (e.g. because we're in
the process of indexing it).
## Chain Cover Index
Synapse computes auth chain differences by pre-computing a "chain cover" index
for the auth chain in a room, allowing efficient reachability queries like "is
event A in the auth chain of event B". This is done by assigning every event a
*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links*
between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A`
is in the auth chain of `B`) if and only if either:
1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s
sequence number; or
2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that
`L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`.
There are actually two potential implementations, one where we store links from
each chain to every other reachable chain (the transitive closure of the links
graph), and one where we remove redundant links (the transitive reduction of the
links graph) e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1`
would not be stored. Synapse uses the former implementations so that it doesn't
need to recurse to test reachability between chains.
### Example
An example auth graph would look like the following, where chains have been
formed based on type/state_key and are denoted by colour and are labelled with
`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey
are those that would be remove in the second implementation described above).
![Example](auth_chain_diff.dot.png)
Note that we don't include all links between events and their auth events, as
most of those links would be redundant. For example, all events point to the
create event, but each chain only needs the one link from it's base to the
create event.
## Using the Index
This index can be used to calculate the auth chain difference of the state sets
by looking at the chain ID and sequence numbers reachable from each state set:
1. For every state set lookup the chain ID/sequence numbers of each state event
2. Use the index to find all chains and the maximum sequence number reachable
from each state set.
3. The auth chain difference is then all events in each chain that have sequence
numbers between the maximum sequence number reachable from *any* state set and
the minimum reachable by *all* state sets (if any).
Note that steps 2 is effectively calculating the auth chain for each state set
(in terms of chain IDs and sequence numbers), and step 3 is calculating the
difference between the union and intersection of the auth chains.
### Worked Example
For example, given the above graph, we can calculate the difference between
state sets consisting of:
1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and
2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`.
Using the index we see that the following auth chains are reachable from each
state set:
1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)`
2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)`
And so, for each the ranges that are in the auth chain difference:
1. Chain 1: None, (since everything can reach the create event).
2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state
sets and the maximum reachable is `2` (corresponding to Bob's second join).
3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power
level).
4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins).
So the final result is: Bob's second join `(2,2)`, the second power level
`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`.

View file

@ -42,11 +42,10 @@ as follows:
* For other installation mechanisms, see the documentation provided by the
maintainer.
To enable the OpenID integration, you should then add an `oidc_config` section
to your configuration file (or uncomment the `enabled: true` line in the
existing section). See [sample_config.yaml](./sample_config.yaml) for some
sample settings, as well as the text below for example configurations for
specific providers.
To enable the OpenID integration, you should then add a section to the `oidc_providers`
setting in your configuration file (or uncomment one of the existing examples).
See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
the text below for example configurations for specific providers.
## Sample configs
@ -62,8 +61,9 @@ Directory (tenant) ID as it will be used in the Azure links.
Edit your Synapse config file and change the `oidc_config` section:
```yaml
oidc_config:
enabled: true
oidc_providers:
- idp_id: microsoft
idp_name: Microsoft
issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
client_id: "<client id>"
client_secret: "<client secret>"
@ -103,8 +103,9 @@ Run with `dex serve examples/config-dev.yaml`.
Synapse config:
```yaml
oidc_config:
enabled: true
oidc_providers:
- idp_id: dex
idp_name: "My Dex server"
skip_verification: true # This is needed as Dex is served on an insecure endpoint
issuer: "http://127.0.0.1:5556/dex"
client_id: "synapse"
@ -152,12 +153,17 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
8. Copy Secret
```yaml
oidc_config:
enabled: true
oidc_providers:
- idp_id: keycloak
idp_name: "My KeyCloak server"
issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
client_id: "synapse"
client_secret: "copy secret generated from above"
scopes: ["openid", "profile"]
user_mapping_provider:
config:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}"
```
### [Auth0][auth0]
@ -187,8 +193,9 @@ oidc_config:
Synapse config:
```yaml
oidc_config:
enabled: true
oidc_providers:
- idp_id: auth0
idp_name: Auth0
issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
@ -215,8 +222,9 @@ does not return a `sub` property, an alternative `subject_claim` has to be set.
Synapse config:
```yaml
oidc_config:
enabled: true
oidc_providers:
- idp_id: github
idp_name: Github
discover: false
issuer: "https://github.com/"
client_id: "your-client-id" # TO BE FILLED
@ -239,8 +247,9 @@ oidc_config:
2. add an "OAuth Client ID" for a Web Application under "Credentials".
3. Copy the Client ID and Client Secret, and add the following to your synapse config:
```yaml
oidc_config:
enabled: true
oidc_providers:
- idp_id: google
idp_name: Google
issuer: "https://accounts.google.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
@ -262,8 +271,9 @@ oidc_config:
Synapse config:
```yaml
oidc_config:
enabled: true
oidc_providers:
- idp_id: twitch
idp_name: Twitch
issuer: "https://id.twitch.tv/oauth2/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
@ -283,8 +293,9 @@ oidc_config:
Synapse config:
```yaml
oidc_config:
enabled: true
oidc_providers:
- idp_id: gitlab
idp_name: Gitlab
issuer: "https://gitlab.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED

View file

@ -18,7 +18,7 @@ connect to a postgres database.
virtualenv](../INSTALL.md#installing-from-source), you can install
the library with:
~/synapse/env/bin/pip install matrix-synapse[postgres]
~/synapse/env/bin/pip install "matrix-synapse[postgres]"
(substituting the path to your virtualenv for `~/synapse/env`, if
you used a different path). You will require the postgres

View file

@ -67,11 +67,16 @@ pid_file: DATADIR/homeserver.pid
#
#web_client_location: https://riot.example.com/
# The public-facing base URL that clients use to access this HS
# (not including _matrix/...). This is the same URL a user would
# enter into the 'custom HS URL' field on their client. If you
# use synapse with a reverse proxy, this should be the URL to reach
# synapse via the proxy.
# The public-facing base URL that clients use to access this Homeserver (not
# including _matrix/...). This is the same URL a user might enter into the
# 'Custom Homeserver URL' field on their client. If you use Synapse with a
# reverse proxy, this should be the URL to reach Synapse via the proxy.
# Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
# 'listeners' below).
#
# If this is left unset, it defaults to 'https://<server_name>/'. (Note that
# that will not work unless you configure Synapse or a reverse-proxy to listen
# on port 443.)
#
#public_baseurl: https://example.com/
@ -1150,8 +1155,9 @@ account_validity:
# send an email to the account's email address with a renewal link. By
# default, no such emails are sent.
#
# If you enable this setting, you will also need to fill out the 'email' and
# 'public_baseurl' configuration sections.
# If you enable this setting, you will also need to fill out the 'email'
# configuration section. You should also check that 'public_baseurl' is set
# correctly.
#
#renew_at: 1w
@ -1242,8 +1248,7 @@ account_validity:
# The identity server which we suggest that clients should use when users log
# in on this server.
#
# (By default, no suggestion is made, so it is left up to the client.
# This setting is ignored unless public_baseurl is also set.)
# (By default, no suggestion is made, so it is left up to the client.)
#
#default_identity_server: https://matrix.org
@ -1268,8 +1273,6 @@ account_validity:
# by the Matrix Identity Service API specification:
# https://matrix.org/docs/spec/identity_service/latest
#
# If a delegate is specified, the config option public_baseurl must also be filled out.
#
account_threepid_delegates:
#email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
@ -1709,141 +1712,153 @@ saml2_config:
#idp_entityid: 'https://our_idp/entityid'
# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
# and login.
#
# Options for each entry include:
#
# idp_id: a unique identifier for this identity provider. Used internally
# by Synapse; should be a single word such as 'github'.
#
# Note that, if this is changed, users authenticating via that provider
# will no longer be recognised as the same user!
#
# idp_name: A user-facing name for this identity provider, which is used to
# offer the user a choice of login mechanisms.
#
# idp_icon: An optional icon for this identity provider, which is presented
# by identity picker pages. If given, must be an MXC URI of the format
# mxc://<server-name>/<media-id>
#
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
# to discover endpoints. Defaults to true.
#
# issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
# is enabled) to discover the provider's endpoints.
#
# client_id: Required. oauth2 client id to use.
#
# client_secret: Required. oauth2 client secret to use.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
# 'none'.
#
# scopes: list of scopes to request. This should normally include the "openid"
# scope. Defaults to ["openid"].
#
# authorization_endpoint: the oauth2 authorization endpoint. Required if
# provider discovery is disabled.
#
# token_endpoint: the oauth2 token endpoint. Required if provider discovery is
# disabled.
#
# userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
# disabled and the 'openid' scope is not requested.
#
# jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
# the 'openid' scope is used.
#
# skip_verification: set to 'true' to skip metadata verification. Use this if
# you are connecting to a provider that is not OpenID Connect compliant.
# Defaults to false. Avoid this in production.
#
# user_profile_method: Whether to fetch the user profile from the userinfo
# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
#
# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
# userinfo endpoint.
#
# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
# match a pre-existing account instead of failing. This could be used if
# switching from password logins to OIDC. Defaults to false.
#
# user_mapping_provider: Configuration for how attributes returned from a OIDC
# provider are mapped onto a matrix user. This setting has the following
# sub-properties:
#
# module: The class name of a custom mapping module. Default is
# 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
# for information on implementing a custom mapping provider.
#
# config: Configuration for the mapping provider module. This section will
# be passed as a Python dictionary to the user mapping provider
# module's `parse_config` method.
#
# For the default provider, the following settings are available:
#
# sub: name of the claim containing a unique identifier for the
# user. Defaults to 'sub', which OpenID Connect compliant
# providers should provide.
#
# localpart_template: Jinja2 template for the localpart of the MXID.
# If this is not set, the user will be prompted to choose their
# own username.
#
# display_name_template: Jinja2 template for the display name to set
# on first login. If unset, no displayname will be set.
#
# extra_attributes: a map of Jinja2 templates for extra attributes
# to send back to the client during login.
# Note that these are non-standard and clients will ignore them
# without modifications.
#
# When rendering, the Jinja2 templates are given a 'user' variable,
# which is set to the claims returned by the UserInfo Endpoint and/or
# in the ID Token.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
# for some example configurations.
# for information on how to configure these options.
#
oidc_config:
# Uncomment the following to enable authorization against an OpenID Connect
# server. Defaults to false.
# For backwards compatibility, it is also possible to configure a single OIDC
# provider via an 'oidc_config' setting. This is now deprecated and admins are
# advised to migrate to the 'oidc_providers' format.
#
oidc_providers:
# Generic example
#
#enabled: true
#- idp_id: my_idp
# idp_name: "My OpenID provider"
# discover: false
# issuer: "https://accounts.example.com/"
# client_id: "provided-by-your-issuer"
# client_secret: "provided-by-your-issuer"
# client_auth_method: client_secret_post
# scopes: ["openid", "profile"]
# authorization_endpoint: "https://accounts.example.com/oauth2/auth"
# token_endpoint: "https://accounts.example.com/oauth2/token"
# userinfo_endpoint: "https://accounts.example.com/userinfo"
# jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
# skip_verification: true
# Uncomment the following to disable use of the OIDC discovery mechanism to
# discover endpoints. Defaults to true.
# For use with Keycloak
#
#discover: false
#- idp_id: keycloak
# idp_name: Keycloak
# issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
# client_id: "synapse"
# client_secret: "copy secret generated in Keycloak UI"
# scopes: ["openid", "profile"]
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
# discover the provider's endpoints.
# For use with Github
#
# Required if 'enabled' is true.
#
#issuer: "https://accounts.example.com/"
# oauth2 client id to use.
#
# Required if 'enabled' is true.
#
#client_id: "provided-by-your-issuer"
# oauth2 client secret to use.
#
# Required if 'enabled' is true.
#
#client_secret: "provided-by-your-issuer"
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic' (default), 'client_secret_post' and
# 'none'.
#
#client_auth_method: client_secret_post
# list of scopes to request. This should normally include the "openid" scope.
# Defaults to ["openid"].
#
#scopes: ["openid", "profile"]
# the oauth2 authorization endpoint. Required if provider discovery is disabled.
#
#authorization_endpoint: "https://accounts.example.com/oauth2/auth"
# the oauth2 token endpoint. Required if provider discovery is disabled.
#
#token_endpoint: "https://accounts.example.com/oauth2/token"
# the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested.
#
#userinfo_endpoint: "https://accounts.example.com/userinfo"
# URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used.
#
#jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
# Uncomment to skip metadata verification. Defaults to false.
#
# Use this if you are connecting to a provider that is not OpenID Connect
# compliant.
# Avoid this in production.
#
#skip_verification: true
# Whether to fetch the user profile from the userinfo endpoint. Valid
# values are: "auto" or "userinfo_endpoint".
#
# Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
# in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
#
#user_profile_method: "userinfo_endpoint"
# Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
# of failing. This could be used if switching from password logins to OIDC. Defaults to false.
#
#allow_existing_users: true
# An external module can be provided here as a custom solution to mapping
# attributes returned from a OIDC provider onto a matrix user.
#
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
# Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
# for information on implementing a custom mapping provider.
#
#module: mapping_provider.OidcMappingProvider
# Custom configuration values for the module. This section will be passed as
# a Python dictionary to the user mapping provider module's `parse_config`
# method.
#
# The examples below are intended for the default provider: they should be
# changed if using a custom provider.
#
config:
# name of the claim containing a unique identifier for the user.
# Defaults to `sub`, which OpenID Connect compliant providers should provide.
#
#subject_claim: "sub"
# Jinja2 template for the localpart of the MXID.
#
# When rendering, this template is given the following variables:
# * user: The claims returned by the UserInfo Endpoint and/or in the ID
# Token
#
# If this is not set, the user will be prompted to choose their
# own username.
#
#localpart_template: "{{ user.preferred_username }}"
# Jinja2 template for the display name to set on first login.
#
# If unset, no displayname will be set.
#
#display_name_template: "{{ user.given_name }} {{ user.last_name }}"
# Jinja2 templates for extra attributes to send back to the client during
# login.
#
# Note that these are non-standard and clients will ignore them without modifications.
#
#extra_attributes:
#birthdate: "{{ user.birthdate }}"
#- idp_id: google
# idp_name: Google
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
# client_secret: "your-client-secret" # TO BE FILLED
# authorization_endpoint: "https://github.com/login/oauth/authorize"
# token_endpoint: "https://github.com/login/oauth/access_token"
# userinfo_endpoint: "https://api.github.com/user"
# scopes: ["read:user"]
# user_mapping_provider:
# config:
# subject_claim: "id"
# localpart_template: "{ user.login }"
# display_name_template: "{ user.name }"
# Enable Central Authentication Service (CAS) for registration and login.
@ -1893,9 +1908,9 @@ sso:
# phishing attacks from evil.site. To avoid this, include a slash after the
# hostname: "https://my.client/".
#
# If public_baseurl is set, then the login fallback page (used by clients
# that don't natively support the required login flows) is whitelisted in
# addition to any URLs in this list.
# The login fallback page (used by clients that don't natively support the
# required login flows) is automatically whitelisted in addition to any URLs
# in this list.
#
# By default, this list is empty.
#
@ -1909,6 +1924,31 @@ sso:
#
# Synapse will look for the following templates in this directory:
#
# * HTML page to prompt the user to choose an Identity Provider during
# login: 'sso_login_idp_picker.html'.
#
# This is only used if multiple SSO Identity Providers are configured.
#
# When rendering, this template is given the following variables:
# * redirect_url: the URL that the user will be redirected to after
# login. Needs manual escaping (see
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * server_name: the homeserver's name.
#
# * providers: a list of available Identity Providers. Each element is
# an object with the following attributes:
# * idp_id: unique identifier for the IdP
# * idp_name: user-facing name for the IdP
#
# The rendered HTML page should contain a form which submits its results
# back as a GET request, with the following query parameters:
#
# * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
# to the template)
#
# * idp: the 'idp_id' of the chosen IDP.
#
# * HTML page for a confirmation step before redirecting back to the client
# with the login token: 'sso_redirect_confirm.html'.
#
@ -1944,6 +1984,14 @@ sso:
#
# This template has no additional variables.
#
# * HTML page shown after a user-interactive authentication session which
# does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
#
# When rendering, this template is given the following variables:
# * server_name: the homeserver's name.
# * user_id_to_verify: the MXID of the user that we are trying to
# validate.
#
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
# attempts to login: 'sso_account_deactivated.html'.
#

View file

@ -31,7 +31,7 @@ There is no need for a separate configuration file for the master process.
1. Adjust synapse configuration files as above.
1. Copy the `*.service` and `*.target` files in [system](system) to
`/etc/systemd/system`.
1. Run `systemctl deamon-reload` to tell systemd to load the new unit files.
1. Run `systemctl daemon-reload` to tell systemd to load the new unit files.
1. Run `systemctl enable matrix-synapse.service`. This will configure the
synapse master process to be started as part of the `matrix-synapse.target`
target.

View file

@ -16,6 +16,9 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only
be used for demo purposes and any admin considering workers should already be
running PostgreSQL.
See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability
for a higher level overview.
## Main process/worker communication
The processes communicate with each other via a Synapse-specific protocol called
@ -56,7 +59,7 @@ The appropriate dependencies must also be installed for Synapse. If using a
virtualenv, these can be installed with:
```sh
pip install matrix-synapse[redis]
pip install "matrix-synapse[redis]"
```
Note that these dependencies are included when synapse is installed with `pip
@ -214,6 +217,7 @@ expressions:
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
^/_matrix/client/(api/v1|r0|unstable)/devices$
^/_matrix/client/(api/v1|r0|unstable)/keys/query$
^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
^/_matrix/client/versions$

View file

@ -100,9 +100,11 @@ files =
synapse/util/async_helpers.py,
synapse/util/caches,
synapse/util/metrics.py,
synapse/util/stringutils.py,
tests/replication,
tests/test_utils,
tests/handlers/test_password_providers.py,
tests/rest/client/v1/test_login.py,
tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_stream_change_cache.py

View file

@ -70,7 +70,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = {
"events": ["processed", "outlier", "contains_url"],
"rooms": ["is_public"],
"rooms": ["is_public", "has_auth_chain_index"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
"presence_stream": ["currently_active"],
@ -629,6 +629,7 @@ class Porter(object):
await self._setup_state_group_id_seq()
await self._setup_user_id_seq()
await self._setup_events_stream_seqs()
await self._setup_device_inbox_seq()
# Step 3. Get tables.
self.progress.set_state("Fetching tables")
@ -911,6 +912,32 @@ class Porter(object):
"_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
)
async def _setup_device_inbox_seq(self):
"""Set the device inbox sequence to the correct value.
"""
curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="device_inbox",
keyvalues={},
retcol="COALESCE(MAX(stream_id), 1)",
allow_none=True,
)
curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="device_federation_outbox",
keyvalues={},
retcol="COALESCE(MAX(stream_id), 1)",
allow_none=True,
)
next_id = max(curr_local_id, curr_federation_id) + 1
def r(txn):
txn.execute(
"ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,)
)
return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r)
##############################################
# The following is simply UI stuff

View file

@ -15,16 +15,7 @@
# Stub for frozendict.
from typing import (
Any,
Hashable,
Iterable,
Iterator,
Mapping,
overload,
Tuple,
TypeVar,
)
from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload
_KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type.

View file

@ -7,17 +7,17 @@ from typing import (
Callable,
Dict,
Hashable,
Iterator,
Iterable,
ItemsView,
Iterable,
Iterator,
KeysView,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Tuple,
Union,
ValuesView,
overload,

View file

@ -16,7 +16,7 @@
"""Contains *incomplete* type hints for txredisapi.
"""
from typing import List, Optional, Union, Type
from typing import List, Optional, Type, Union
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...

View file

@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.25.0"
__version__ = "1.26.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -33,6 +33,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
@ -186,8 +187,8 @@ class Auth:
AuthError if access is denied for the user in the access token
"""
try:
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.get_user_agent("")
ip_addr = request.getClientIP()
user_agent = get_request_user_agent(request)
access_token = self.get_access_token_from_request(request)
@ -275,7 +276,7 @@ class Auth:
return None, None
if app_service.ip_range_whitelist:
ip_address = IPAddress(self.hs.get_ip_from_request(request))
ip_address = IPAddress(request.getClientIP())
if ip_address not in app_service.ip_range_whitelist:
return None, None

View file

@ -51,11 +51,11 @@ class RoomDisposition:
class RoomVersion:
"""An object which describes the unique attributes of a room version."""
identifier = attr.ib() # str; the identifier for this version
disposition = attr.ib() # str; one of the RoomDispositions
event_format = attr.ib() # int; one of the EventFormatVersions
state_res = attr.ib() # int; one of the StateResolutionVersions
enforce_key_validity = attr.ib() # bool
identifier = attr.ib(type=str) # the identifier for this version
disposition = attr.ib(type=str) # one of the RoomDispositions
event_format = attr.ib(type=int) # one of the EventFormatVersions
state_res = attr.ib(type=int) # one of the StateResolutionVersions
enforce_key_validity = attr.ib(type=bool)
# bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
special_case_aliases_auth = attr.ib(type=bool)
@ -64,9 +64,11 @@ class RoomVersion:
# * Floats
# * NaN, Infinity, -Infinity
strict_canonicaljson = attr.ib(type=bool)
# bool: MSC2209: Check 'notifications' key while verifying
# MSC2209: Check 'notifications' key while verifying
# m.room.power_levels auth rules.
limit_notifications_power_levels = attr.ib(type=bool)
# MSC2174/MSC2176: Apply updated redaction rules algorithm.
msc2176_redaction_rules = attr.ib(type=bool)
class RoomVersions:
@ -79,6 +81,7 @@ class RoomVersions:
special_case_aliases_auth=True,
strict_canonicaljson=False,
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
)
V2 = RoomVersion(
"2",
@ -89,6 +92,7 @@ class RoomVersions:
special_case_aliases_auth=True,
strict_canonicaljson=False,
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
)
V3 = RoomVersion(
"3",
@ -99,6 +103,7 @@ class RoomVersions:
special_case_aliases_auth=True,
strict_canonicaljson=False,
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
)
V4 = RoomVersion(
"4",
@ -109,6 +114,7 @@ class RoomVersions:
special_case_aliases_auth=True,
strict_canonicaljson=False,
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
)
V5 = RoomVersion(
"5",
@ -119,6 +125,7 @@ class RoomVersions:
special_case_aliases_auth=True,
strict_canonicaljson=False,
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
)
V6 = RoomVersion(
"6",
@ -129,6 +136,18 @@ class RoomVersions:
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
)
MSC2176 = RoomVersion(
"org.matrix.msc2176",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=True,
)
@ -141,5 +160,6 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V4,
RoomVersions.V5,
RoomVersions.V6,
RoomVersions.MSC2176,
)
} # type: Dict[str, RoomVersion]

View file

@ -42,8 +42,6 @@ class ConsentURIBuilder:
"""
if hs_config.form_secret is None:
raise ConfigError("form_secret not set in config")
if hs_config.public_baseurl is None:
raise ConfigError("public_baseurl not set in config")
self._hmac_secret = hs_config.form_secret.encode("utf-8")
self._public_baseurl = hs_config.public_baseurl

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
# Copyright 2019-2021 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,7 +20,7 @@ import signal
import socket
import sys
import traceback
from typing import Iterable
from typing import Awaitable, Callable, Iterable
from typing_extensions import NoReturn
@ -143,6 +144,45 @@ def quit_with_error(error_string: str) -> NoReturn:
sys.exit(1)
def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
"""Register a callback with the reactor, to be called once it is running
This can be used to initialise parts of the system which require an asynchronous
setup.
Any exception raised by the callback will be printed and logged, and the process
will exit.
"""
async def wrapper():
try:
await cb(*args, **kwargs)
except Exception:
# previously, we used Failure().printTraceback() here, in the hope that
# would give better tracebacks than traceback.print_exc(). However, that
# doesn't handle chained exceptions (with a __cause__ or __context__) well,
# and I *think* the need for Failure() is reduced now that we mostly use
# async/await.
# Write the exception to both the logs *and* the unredirected stderr,
# because people tend to get confused if it only goes to one or the other.
#
# One problem with this is that if people are using a logging config that
# logs to the console (as is common eg under docker), they will get two
# copies of the exception. We could maybe try to detect that, but it's
# probably a cost we can bear.
logger.fatal("Error during startup", exc_info=True)
print("Error during startup:", file=sys.__stderr__)
traceback.print_exc(file=sys.__stderr__)
# it's no use calling sys.exit here, since that just raises a SystemExit
# exception which is then caught by the reactor, and everything carries
# on as normal.
os._exit(1)
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
def listen_metrics(bind_addresses, port):
"""
Start Prometheus metrics server.
@ -227,7 +267,7 @@ def refresh_certificate(hs):
logger.info("Context factories updated.")
def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
"""
Start a Synapse server or worker.
@ -241,10 +281,8 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
hs: homeserver instance
listeners: Listener configuration ('listeners' in homeserver.yaml)
"""
try:
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
reactor = hs.get_reactor()
@wrap_as_background_process("sighup")
@ -304,12 +342,6 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
if sys.version_info >= (3, 7):
gc.collect()
gc.freeze()
except Exception:
traceback.print_exc(file=sys.stderr)
reactor = hs.get_reactor()
if reactor.running:
reactor.stop()
sys.exit(1)
def setup_sentry(hs):

View file

@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager
from twisted.internet import address, reactor
from twisted.internet import address
import synapse
import synapse.events
@ -34,6 +34,7 @@ from synapse.api.urls import (
SERVER_KEY_V2_PREFIX,
)
from synapse.app import _base
from synapse.app._base import register_start
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
@ -99,21 +100,37 @@ from synapse.rest.client.v1.profile import (
)
from synapse.rest.client.v1.push_rule import PushRuleRestServlet
from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, sync, user_directory
from synapse.rest.client.v2_alpha import (
account_data,
groups,
read_marker,
receipts,
room_keys,
sync,
tags,
user_directory,
)
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
from synapse.rest.client.v2_alpha.account_data import (
AccountDataServlet,
RoomAccountDataServlet,
)
from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
from synapse.rest.client.v2_alpha.devices import DevicesRestServlet
from synapse.rest.client.v2_alpha.keys import (
KeyChangesServlet,
KeyQueryServlet,
OneTimeKeyServlet,
)
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer, cache_in_self
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
from synapse.storage.databases.main.media_repository import MediaRepositoryStore
from synapse.storage.databases.main.metrics import ServerMetricsStore
from synapse.storage.databases.main.monthly_active_users import (
@ -445,6 +462,7 @@ class GenericWorkerSlavedStore(
UserDirectoryStore,
StatsStore,
UIAuthWorkerStore,
EndToEndRoomKeyStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedReceiptsStore,
@ -501,7 +519,9 @@ class GenericWorkerServer(HomeServer):
RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
ThreepidRestServlet(self).register(resource)
DevicesRestServlet(self).register(resource)
KeyQueryServlet(self).register(resource)
OneTimeKeyServlet(self).register(resource)
KeyChangesServlet(self).register(resource)
VoipRestServlet(self).register(resource)
PushRuleRestServlet(self).register(resource)
@ -519,6 +539,13 @@ class GenericWorkerServer(HomeServer):
room.register_servlets(self, resource, True)
room.register_deprecated_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
room_keys.register_servlets(self, resource)
tags.register_servlets(self, resource)
account_data.register_servlets(self, resource)
receipts.register_servlets(self, resource)
read_marker.register_servlets(self, resource)
SendToDeviceRestServlet(self).register(resource)
user_directory.register_servlets(self, resource)
@ -957,9 +984,7 @@ def start(config_options):
# streams. Will no-op if no streams can be written to by this worker.
hs.get_replication_streamer()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, hs, config.worker_listeners
)
register_start(_base.start, hs, config.worker_listeners)
_base.start_worker_reactor("synapse-generic-worker", config)

View file

@ -15,15 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import logging
import os
import sys
from typing import Iterable, Iterator
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from twisted.internet import reactor
from twisted.web.resource import EncodingResourceWrapper, IResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
@ -40,7 +37,7 @@ from synapse.api.urls import (
WEB_CLIENT_PREFIX,
)
from synapse.app import _base
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
from synapse.config._base import ConfigError
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
@ -63,6 +60,7 @@ from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.synapse.client.pick_idp import PickIdpResource
from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
@ -72,7 +70,6 @@ from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.module_loader import load_module
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.homeserver")
@ -194,6 +191,7 @@ class SynapseHomeServer(HomeServer):
"/.well-known/matrix/client": WellKnownResource(self),
"/_synapse/admin": AdminRestResource(self),
"/_synapse/client/pick_username": pick_username_resource(self),
"/_synapse/client/pick_idp": PickIdpResource(self),
}
)
@ -415,7 +413,6 @@ def setup(config_options):
_base.refresh_certificate(hs)
async def start():
try:
# Run the ACME provisioning code, if it's enabled.
if hs.config.acme_enabled:
acme = hs.get_acme_handler()
@ -432,23 +429,12 @@ def setup(config_options):
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
await oidc.load_jwks()
_base.start(hs, config.listeners)
await _base.start(hs, config.listeners)
hs.get_datastore().db_pool.updates.start_doing_background_updates()
except Exception:
# Print the exception and bail out.
print("Error during startup:", file=sys.stderr)
# this gives better tracebacks than traceback.print_exc()
Failure().printTraceback(file=sys.stderr)
if reactor.running:
reactor.stop()
sys.exit(1)
reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
register_start(start)
return hs
@ -485,25 +471,6 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
e = e.__cause__
class SynapseService(service.Service):
"""
A twisted Service class that will start synapse. Used to run synapse
via twistd and a .tac.
"""
def __init__(self, config):
self.config = config
def startService(self):
hs = setup(self.config)
change_resource_limit(hs.config.soft_file_limit)
if hs.config.gc_thresholds:
gc.set_threshold(*hs.config.gc_thresholds)
def stopService(self):
return self._port.stopListening()
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:

View file

@ -252,10 +252,11 @@ class Config:
env = jinja2.Environment(loader=loader, autoescape=autoescape)
# Update the environment with our custom filters
env.filters.update({"format_ts": _format_ts_filter})
if self.public_baseurl:
env.filters.update(
{"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)}
{
"format_ts": _format_ts_filter,
"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
}
)
for filename in filenames:

View file

@ -56,7 +56,7 @@ def json_error_to_config_error(
"""
# copy `config_path` before modifying it.
path = list(config_path)
for p in list(e.path):
for p in list(e.absolute_path):
if isinstance(p, int):
path.append("<item %i>" % p)
else:

View file

@ -40,7 +40,7 @@ class CasConfig(Config):
self.cas_required_attributes = {}
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
return """\
# Enable Central Authentication Service (CAS) for registration and login.
#
cas_config:

View file

@ -166,11 +166,6 @@ class EmailConfig(Config):
if not self.email_notif_from:
missing.append("email.notif_from")
# public_baseurl is required to build password reset and validation links that
# will be emailed to users
if config.get("public_baseurl") is None:
missing.append("public_baseurl")
if missing:
raise ConfigError(
MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),)
@ -269,9 +264,6 @@ class EmailConfig(Config):
if not self.email_notif_from:
missing.append("email.notif_from")
if config.get("public_baseurl") is None:
missing.append("public_baseurl")
if missing:
raise ConfigError(
"email.enable_notifs is True but required keys are missing: %s"

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,8 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import string
from typing import Iterable, Optional, Tuple, Type
import attr
from synapse.config._util import validate_config
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_mxc_uri
from ._base import Config, ConfigError
@ -25,48 +34,285 @@ class OIDCConfig(Config):
section = "oidc"
def read_config(self, config, **kwargs):
self.oidc_enabled = False
oidc_config = config.get("oidc_config")
if not oidc_config or not oidc_config.get("enabled", False):
self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
if not self.oidc_providers:
return
try:
check_requirements("oidc")
except DependencyException as e:
raise ConfigError(e.message)
raise ConfigError(e.message) from e
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("oidc_config requires a public_baseurl to be set")
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
self.oidc_enabled = True
self.oidc_discover = oidc_config.get("discover", True)
self.oidc_issuer = oidc_config["issuer"]
self.oidc_client_id = oidc_config["client_id"]
self.oidc_client_secret = oidc_config["client_secret"]
self.oidc_client_auth_method = oidc_config.get(
"client_auth_method", "client_secret_basic"
)
self.oidc_scopes = oidc_config.get("scopes", ["openid"])
self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
self.oidc_token_endpoint = oidc_config.get("token_endpoint")
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
@property
def oidc_enabled(self) -> bool:
# OIDC is enabled if we have a provider
return bool(self.oidc_providers)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
# and login.
#
# Options for each entry include:
#
# idp_id: a unique identifier for this identity provider. Used internally
# by Synapse; should be a single word such as 'github'.
#
# Note that, if this is changed, users authenticating via that provider
# will no longer be recognised as the same user!
#
# idp_name: A user-facing name for this identity provider, which is used to
# offer the user a choice of login mechanisms.
#
# idp_icon: An optional icon for this identity provider, which is presented
# by identity picker pages. If given, must be an MXC URI of the format
# mxc://<server-name>/<media-id>
#
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
# to discover endpoints. Defaults to true.
#
# issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
# is enabled) to discover the provider's endpoints.
#
# client_id: Required. oauth2 client id to use.
#
# client_secret: Required. oauth2 client secret to use.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
# 'none'.
#
# scopes: list of scopes to request. This should normally include the "openid"
# scope. Defaults to ["openid"].
#
# authorization_endpoint: the oauth2 authorization endpoint. Required if
# provider discovery is disabled.
#
# token_endpoint: the oauth2 token endpoint. Required if provider discovery is
# disabled.
#
# userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
# disabled and the 'openid' scope is not requested.
#
# jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
# the 'openid' scope is used.
#
# skip_verification: set to 'true' to skip metadata verification. Use this if
# you are connecting to a provider that is not OpenID Connect compliant.
# Defaults to false. Avoid this in production.
#
# user_profile_method: Whether to fetch the user profile from the userinfo
# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
#
# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
# userinfo endpoint.
#
# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
# match a pre-existing account instead of failing. This could be used if
# switching from password logins to OIDC. Defaults to false.
#
# user_mapping_provider: Configuration for how attributes returned from a OIDC
# provider are mapped onto a matrix user. This setting has the following
# sub-properties:
#
# module: The class name of a custom mapping module. Default is
# {mapping_provider!r}.
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
# for information on implementing a custom mapping provider.
#
# config: Configuration for the mapping provider module. This section will
# be passed as a Python dictionary to the user mapping provider
# module's `parse_config` method.
#
# For the default provider, the following settings are available:
#
# sub: name of the claim containing a unique identifier for the
# user. Defaults to 'sub', which OpenID Connect compliant
# providers should provide.
#
# localpart_template: Jinja2 template for the localpart of the MXID.
# If this is not set, the user will be prompted to choose their
# own username.
#
# display_name_template: Jinja2 template for the display name to set
# on first login. If unset, no displayname will be set.
#
# extra_attributes: a map of Jinja2 templates for extra attributes
# to send back to the client during login.
# Note that these are non-standard and clients will ignore them
# without modifications.
#
# When rendering, the Jinja2 templates are given a 'user' variable,
# which is set to the claims returned by the UserInfo Endpoint and/or
# in the ID Token.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
# for information on how to configure these options.
#
# For backwards compatibility, it is also possible to configure a single OIDC
# provider via an 'oidc_config' setting. This is now deprecated and admins are
# advised to migrate to the 'oidc_providers' format.
#
oidc_providers:
# Generic example
#
#- idp_id: my_idp
# idp_name: "My OpenID provider"
# discover: false
# issuer: "https://accounts.example.com/"
# client_id: "provided-by-your-issuer"
# client_secret: "provided-by-your-issuer"
# client_auth_method: client_secret_post
# scopes: ["openid", "profile"]
# authorization_endpoint: "https://accounts.example.com/oauth2/auth"
# token_endpoint: "https://accounts.example.com/oauth2/token"
# userinfo_endpoint: "https://accounts.example.com/userinfo"
# jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
# skip_verification: true
# For use with Keycloak
#
#- idp_id: keycloak
# idp_name: Keycloak
# issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
# client_id: "synapse"
# client_secret: "copy secret generated in Keycloak UI"
# scopes: ["openid", "profile"]
# For use with Github
#
#- idp_id: google
# idp_name: Google
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
# client_secret: "your-client-secret" # TO BE FILLED
# authorization_endpoint: "https://github.com/login/oauth/authorize"
# token_endpoint: "https://github.com/login/oauth/access_token"
# userinfo_endpoint: "https://api.github.com/user"
# scopes: ["read:user"]
# user_mapping_provider:
# config:
# subject_claim: "id"
# localpart_template: "{{ user.login }}"
# display_name_template: "{{ user.name }}"
""".format(
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
)
# jsonschema definition of the configuration settings for an oidc identity provider
OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "object",
"required": ["issuer", "client_id", "client_secret"],
"properties": {
"idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
"idp_name": {"type": "string"},
"idp_icon": {"type": "string"},
"discover": {"type": "boolean"},
"issuer": {"type": "string"},
"client_id": {"type": "string"},
"client_secret": {"type": "string"},
"client_auth_method": {
"type": "string",
# the following list is the same as the keys of
# authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it
# to avoid importing authlib here.
"enum": ["client_secret_basic", "client_secret_post", "none"],
},
"scopes": {"type": "array", "items": {"type": "string"}},
"authorization_endpoint": {"type": "string"},
"token_endpoint": {"type": "string"},
"userinfo_endpoint": {"type": "string"},
"jwks_uri": {"type": "string"},
"skip_verification": {"type": "boolean"},
"user_profile_method": {
"type": "string",
"enum": ["auto", "userinfo_endpoint"],
},
"allow_existing_users": {"type": "boolean"},
"user_mapping_provider": {"type": ["object", "null"]},
},
}
# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
"allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
}
# the `oidc_providers` list can either be None (as it is in the default config), or
# a list of provider configs, each of which requires an explicit ID and name.
OIDC_PROVIDER_LIST_SCHEMA = {
"oneOf": [
{"type": "null"},
{"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
]
}
# the `oidc_config` setting can either be None (which it used to be in the default
# config), or an object. If an object, it is ignored unless it has an "enabled: True"
# property.
#
# It's *possible* to represent this with jsonschema, but the resultant errors aren't
# particularly clear, so we just check for either an object or a null here, and do
# additional checks in the code.
OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
MAIN_CONFIG_SCHEMA = {
"type": "object",
"properties": {
"oidc_config": OIDC_CONFIG_SCHEMA,
"oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
},
}
def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
"""extract and parse the OIDC provider configs from the config dict
The configuration may contain either a single `oidc_config` object with an
`enabled: True` property, or a list of provider configurations under
`oidc_providers`, *or both*.
Returns a generator which yields the OidcProviderConfig objects
"""
validate_config(MAIN_CONFIG_SCHEMA, config, ())
for i, p in enumerate(config.get("oidc_providers") or []):
yield _parse_oidc_config_dict(p, ("oidc_providers", "<item %i>" % (i,)))
# for backwards-compatibility, it is also possible to provide a single "oidc_config"
# object with an "enabled: True" property.
oidc_config = config.get("oidc_config")
if oidc_config and oidc_config.get("enabled", False):
# MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
# it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
# above), so now we need to validate it.
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
yield _parse_oidc_config_dict(oidc_config, ("oidc_config",))
def _parse_oidc_config_dict(
oidc_config: JsonDict, config_path: Tuple[str, ...]
) -> "OidcProviderConfig":
"""Take the configuration dict and parse it into an OidcProviderConfig
Raises:
ConfigError if the configuration is malformed.
"""
ump_config = oidc_config.get("user_mapping_provider", {})
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
ump_config.setdefault("config", {})
(
self.oidc_user_mapping_provider_class,
self.oidc_user_mapping_provider_config,
) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
(user_mapping_provider_class, user_mapping_provider_config,) = load_module(
ump_config, config_path + ("user_mapping_provider",)
)
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [
@ -76,151 +322,124 @@ class OIDCConfig(Config):
missing_methods = [
method
for method in required_methods
if not hasattr(self.oidc_user_mapping_provider_class, method)
if not hasattr(user_mapping_provider_class, method)
]
if missing_methods:
raise ConfigError(
"Class specified by oidc_config."
"user_mapping_provider.module is missing required "
"methods: %s" % (", ".join(missing_methods),)
"Class %s is missing required "
"methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),),
config_path + ("user_mapping_provider", "module"),
)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
# for some example configurations.
#
oidc_config:
# Uncomment the following to enable authorization against an OpenID Connect
# server. Defaults to false.
#
#enabled: true
# MSC2858 will apply certain limits in what can be used as an IdP id, so let's
# enforce those limits now.
# TODO: factor out this stuff to a generic function
idp_id = oidc_config.get("idp_id", "oidc")
valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._")
# Uncomment the following to disable use of the OIDC discovery mechanism to
# discover endpoints. Defaults to true.
#
#discover: false
if any(c not in valid_idp_chars for c in idp_id):
raise ConfigError(
'idp_id may only contain a-z, 0-9, "-", ".", "_"',
config_path + ("idp_id",),
)
if idp_id[0] not in string.ascii_lowercase:
raise ConfigError(
"idp_id must start with a-z", config_path + ("idp_id",),
)
# MSC2858 also specifies that the idp_icon must be a valid MXC uri
idp_icon = oidc_config.get("idp_icon")
if idp_icon is not None:
try:
parse_and_validate_mxc_uri(idp_icon)
except ValueError as e:
raise ConfigError(
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
) from e
return OidcProviderConfig(
idp_id=idp_id,
idp_name=oidc_config.get("idp_name", "OIDC"),
idp_icon=idp_icon,
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
client_secret=oidc_config["client_secret"],
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
scopes=oidc_config.get("scopes", ["openid"]),
authorization_endpoint=oidc_config.get("authorization_endpoint"),
token_endpoint=oidc_config.get("token_endpoint"),
userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
jwks_uri=oidc_config.get("jwks_uri"),
skip_verification=oidc_config.get("skip_verification", False),
user_profile_method=oidc_config.get("user_profile_method", "auto"),
allow_existing_users=oidc_config.get("allow_existing_users", False),
user_mapping_provider_class=user_mapping_provider_class,
user_mapping_provider_config=user_mapping_provider_config,
)
@attr.s(slots=True, frozen=True)
class OidcProviderConfig:
# a unique identifier for this identity provider. Used in the 'user_external_ids'
# table, as well as the query/path parameter used in the login protocol.
idp_id = attr.ib(type=str)
# user-facing name for this identity provider.
idp_name = attr.ib(type=str)
# Optional MXC URI for icon for this IdP.
idp_icon = attr.ib(type=Optional[str])
# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
# discover the provider's endpoints.
#
# Required if 'enabled' is true.
#
#issuer: "https://accounts.example.com/"
issuer = attr.ib(type=str)
# oauth2 client id to use.
#
# Required if 'enabled' is true.
#
#client_id: "provided-by-your-issuer"
# oauth2 client id to use
client_id = attr.ib(type=str)
# oauth2 client secret to use.
#
# Required if 'enabled' is true.
#
#client_secret: "provided-by-your-issuer"
# oauth2 client secret to use
client_secret = attr.ib(type=str)
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic' (default), 'client_secret_post' and
# Valid values are 'client_secret_basic', 'client_secret_post' and
# 'none'.
#
#client_auth_method: client_secret_post
client_auth_method = attr.ib(type=str)
# list of scopes to request. This should normally include the "openid" scope.
# Defaults to ["openid"].
#
#scopes: ["openid", "profile"]
# list of scopes to request
scopes = attr.ib(type=Collection[str])
# the oauth2 authorization endpoint. Required if provider discovery is disabled.
#
#authorization_endpoint: "https://accounts.example.com/oauth2/auth"
# the oauth2 authorization endpoint. Required if discovery is disabled.
authorization_endpoint = attr.ib(type=Optional[str])
# the oauth2 token endpoint. Required if provider discovery is disabled.
#
#token_endpoint: "https://accounts.example.com/oauth2/token"
# the oauth2 token endpoint. Required if discovery is disabled.
token_endpoint = attr.ib(type=Optional[str])
# the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested.
#
#userinfo_endpoint: "https://accounts.example.com/userinfo"
userinfo_endpoint = attr.ib(type=Optional[str])
# URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used.
#
#jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
jwks_uri = attr.ib(type=Optional[str])
# Uncomment to skip metadata verification. Defaults to false.
#
# Use this if you are connecting to a provider that is not OpenID Connect
# compliant.
# Avoid this in production.
#
#skip_verification: true
# Whether to skip metadata verification
skip_verification = attr.ib(type=bool)
# Whether to fetch the user profile from the userinfo endpoint. Valid
# values are: "auto" or "userinfo_endpoint".
#
# Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
# in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
#
#user_profile_method: "userinfo_endpoint"
user_profile_method = attr.ib(type=str)
# Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
# of failing. This could be used if switching from password logins to OIDC. Defaults to false.
#
#allow_existing_users: true
# whether to allow a user logging in via OIDC to match a pre-existing account
# instead of failing
allow_existing_users = attr.ib(type=bool)
# An external module can be provided here as a custom solution to mapping
# attributes returned from a OIDC provider onto a matrix user.
#
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
# Default is {mapping_provider!r}.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
# for information on implementing a custom mapping provider.
#
#module: mapping_provider.OidcMappingProvider
# the class of the user mapping provider
user_mapping_provider_class = attr.ib(type=Type)
# Custom configuration values for the module. This section will be passed as
# a Python dictionary to the user mapping provider module's `parse_config`
# method.
#
# The examples below are intended for the default provider: they should be
# changed if using a custom provider.
#
config:
# name of the claim containing a unique identifier for the user.
# Defaults to `sub`, which OpenID Connect compliant providers should provide.
#
#subject_claim: "sub"
# Jinja2 template for the localpart of the MXID.
#
# When rendering, this template is given the following variables:
# * user: The claims returned by the UserInfo Endpoint and/or in the ID
# Token
#
# If this is not set, the user will be prompted to choose their
# own username.
#
#localpart_template: "{{{{ user.preferred_username }}}}"
# Jinja2 template for the display name to set on first login.
#
# If unset, no displayname will be set.
#
#display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
# Jinja2 templates for extra attributes to send back to the client during
# login.
#
# Note that these are non-standard and clients will ignore them without modifications.
#
#extra_attributes:
#birthdate: "{{{{ user.birthdate }}}}"
""".format(
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
)
# the config of the user mapping provider
user_mapping_provider_config = attr.ib()

View file

@ -14,14 +14,13 @@
# limitations under the License.
import os
from distutils.util import strtobool
import pkg_resources
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias, UserID
from synapse.util.stringutils import random_string_with_symbols
from synapse.util.stringutils import random_string_with_symbols, strtobool
class AccountValidityConfig(Config):
@ -50,10 +49,6 @@ class AccountValidityConfig(Config):
self.startup_job_max_delta = self.period * 10.0 / 100.0
if self.renew_by_email_enabled:
if "public_baseurl" not in synapse_config:
raise ConfigError("Can't send renewal emails without 'public_baseurl'")
template_dir = config.get("template_dir")
if not template_dir:
@ -86,12 +81,12 @@ class RegistrationConfig(Config):
section = "registration"
def read_config(self, config, **kwargs):
self.enable_registration = bool(
strtobool(str(config.get("enable_registration", False)))
self.enable_registration = strtobool(
str(config.get("enable_registration", False))
)
if "disable_registration" in config:
self.enable_registration = not bool(
strtobool(str(config["disable_registration"]))
self.enable_registration = not strtobool(
str(config["disable_registration"])
)
self.account_validity = AccountValidityConfig(
@ -110,13 +105,6 @@ class RegistrationConfig(Config):
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
if self.account_threepid_delegate_msisdn and not self.public_baseurl:
raise ConfigError(
"The configuration option `public_baseurl` is required if "
"`account_threepid_delegate.msisdn` is set, such that "
"clients know where to submit validation tokens to. Please "
"configure `public_baseurl`."
)
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
@ -241,8 +229,9 @@ class RegistrationConfig(Config):
# send an email to the account's email address with a renewal link. By
# default, no such emails are sent.
#
# If you enable this setting, you will also need to fill out the 'email' and
# 'public_baseurl' configuration sections.
# If you enable this setting, you will also need to fill out the 'email'
# configuration section. You should also check that 'public_baseurl' is set
# correctly.
#
#renew_at: 1w
@ -333,8 +322,7 @@ class RegistrationConfig(Config):
# The identity server which we suggest that clients should use when users log
# in on this server.
#
# (By default, no suggestion is made, so it is left up to the client.
# This setting is ignored unless public_baseurl is also set.)
# (By default, no suggestion is made, so it is left up to the client.)
#
#default_identity_server: https://matrix.org
@ -359,8 +347,6 @@ class RegistrationConfig(Config):
# by the Matrix Identity Service API specification:
# https://matrix.org/docs/spec/identity_service/latest
#
# If a delegate is specified, the config option public_baseurl must also be filled out.
#
account_threepid_delegates:
#email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process

View file

@ -189,8 +189,6 @@ class SAML2Config(Config):
import saml2
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)

View file

@ -26,7 +26,7 @@ import yaml
from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.util.stringutils import parse_and_validate_server_name
from ._base import Config, ConfigError
@ -161,7 +161,11 @@ class ServerConfig(Config):
self.print_pidfile = config.get("print_pidfile")
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl")
self.public_baseurl = config.get("public_baseurl") or "https://%s/" % (
self.server_name,
)
if self.public_baseurl[-1] != "/":
self.public_baseurl += "/"
# Whether to enable user presence.
self.use_presence = config.get("use_presence", True)
@ -317,9 +321,6 @@ class ServerConfig(Config):
# Always blacklist 0.0.0.0, ::
self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/":
self.public_baseurl += "/"
self.start_pushers = config.get("start_pushers", True)
# (undocumented) option for torturing the worker-mode replication a bit,
@ -740,11 +741,16 @@ class ServerConfig(Config):
#
#web_client_location: https://riot.example.com/
# The public-facing base URL that clients use to access this HS
# (not including _matrix/...). This is the same URL a user would
# enter into the 'custom HS URL' field on their client. If you
# use synapse with a reverse proxy, this should be the URL to reach
# synapse via the proxy.
# The public-facing base URL that clients use to access this Homeserver (not
# including _matrix/...). This is the same URL a user might enter into the
# 'Custom Homeserver URL' field on their client. If you use Synapse with a
# reverse proxy, this should be the URL to reach Synapse via the proxy.
# Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
# 'listeners' below).
#
# If this is left unset, it defaults to 'https://<server_name>/'. (Note that
# that will not work unless you configure Synapse or a reverse-proxy to listen
# on port 443.)
#
#public_baseurl: https://example.com/

View file

@ -31,18 +31,22 @@ class SSOConfig(Config):
# Read templates from disk
(
self.sso_login_idp_picker_template,
self.sso_redirect_confirm_template,
self.sso_auth_confirm_template,
self.sso_error_template,
sso_account_deactivated_template,
sso_auth_success_template,
self.sso_auth_bad_user_template,
) = self.read_templates(
[
"sso_login_idp_picker.html",
"sso_redirect_confirm.html",
"sso_auth_confirm.html",
"sso_error.html",
"sso_account_deactivated.html",
"sso_auth_success.html",
"sso_auth_bad_user.html",
],
template_dir,
)
@ -60,9 +64,6 @@ class SSOConfig(Config):
# gracefully to the client). This would make it pointless to ask the user for
# confirmation, since the URL the confirmation page would be showing wouldn't be
# the client's.
# public_baseurl is an optional setting, so we only add the fallback's URL to the
# list if it's provided (because we can't figure out what that URL is otherwise).
if self.public_baseurl:
login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
self.sso_client_whitelist.append(login_fallback_url)
@ -82,9 +83,9 @@ class SSOConfig(Config):
# phishing attacks from evil.site. To avoid this, include a slash after the
# hostname: "https://my.client/".
#
# If public_baseurl is set, then the login fallback page (used by clients
# that don't natively support the required login flows) is whitelisted in
# addition to any URLs in this list.
# The login fallback page (used by clients that don't natively support the
# required login flows) is automatically whitelisted in addition to any URLs
# in this list.
#
# By default, this list is empty.
#
@ -98,6 +99,31 @@ class SSOConfig(Config):
#
# Synapse will look for the following templates in this directory:
#
# * HTML page to prompt the user to choose an Identity Provider during
# login: 'sso_login_idp_picker.html'.
#
# This is only used if multiple SSO Identity Providers are configured.
#
# When rendering, this template is given the following variables:
# * redirect_url: the URL that the user will be redirected to after
# login. Needs manual escaping (see
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * server_name: the homeserver's name.
#
# * providers: a list of available Identity Providers. Each element is
# an object with the following attributes:
# * idp_id: unique identifier for the IdP
# * idp_name: user-facing name for the IdP
#
# The rendered HTML page should contain a form which submits its results
# back as a GET request, with the following query parameters:
#
# * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
# to the template)
#
# * idp: the 'idp_id' of the chosen IDP.
#
# * HTML page for a confirmation step before redirecting back to the client
# with the login token: 'sso_redirect_confirm.html'.
#
@ -133,6 +159,14 @@ class SSOConfig(Config):
#
# This template has no additional variables.
#
# * HTML page shown after a user-interactive authentication session which
# does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
#
# When rendering, this template is given the following variables:
# * server_name: the homeserver's name.
# * user_id_to_verify: the MXID of the user that we are trying to
# validate.
#
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
# attempts to login: 'sso_account_deactivated.html'.
#

View file

@ -53,6 +53,15 @@ class WriterLocations:
default=["master"], type=List[str], converter=_instance_to_list_converter
)
typing = attr.ib(default="master", type=str)
to_device = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
)
account_data = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
)
receipts = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
)
class WorkerConfig(Config):
@ -124,7 +133,7 @@ class WorkerConfig(Config):
# Check that the configured writers for events and typing also appears in
# `instance_map`.
for stream in ("events", "typing"):
for stream in ("events", "typing", "to_device", "account_data", "receipts"):
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
if instance != "master" and instance not in self.instance_map:
@ -133,6 +142,21 @@ class WorkerConfig(Config):
% (instance, stream)
)
if len(self.writers.to_device) != 1:
raise ConfigError(
"Must only specify one instance to handle `to_device` messages."
)
if len(self.writers.account_data) != 1:
raise ConfigError(
"Must only specify one instance to handle `account_data` messages."
)
if len(self.writers.receipts) != 1:
raise ConfigError(
"Must only specify one instance to handle `receipts` messages."
)
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
# Whether this worker should run background tasks or not.

View file

@ -17,7 +17,6 @@
import abc
import os
from distutils.util import strtobool
from typing import Dict, Optional, Tuple, Type
from unpaddedbase64 import encode_base64
@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
from synapse.util.stringutils import strtobool
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting a
@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze
# NOTE: This is overridden by the configuration by the Synapse worker apps, but
# for the sake of tests, it is set here while it cannot be configured on the
# homeserver object itself.
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))

View file

@ -79,13 +79,15 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
"state_key",
"depth",
"prev_events",
"prev_state",
"auth_events",
"origin",
"origin_server_ts",
"membership",
]
# Room versions from before MSC2176 had additional allowed keys.
if not room_version.msc2176_redaction_rules:
allowed_keys.extend(["prev_state", "membership"])
event_type = event_dict["type"]
new_content = {}
@ -98,6 +100,10 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
if event_type == EventTypes.Member:
add_fields("membership")
elif event_type == EventTypes.Create:
# MSC2176 rules state that create events cannot be redacted.
if room_version.msc2176_redaction_rules:
return event_dict
add_fields("creator")
elif event_type == EventTypes.JoinRules:
add_fields("join_rule")
@ -112,10 +118,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
"kick",
"redact",
)
if room_version.msc2176_redaction_rules:
add_fields("invite")
elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
add_fields("aliases")
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
add_fields("redacts")
allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}

View file

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import (
TYPE_CHECKING,
Any,
@ -48,7 +49,6 @@ from synapse.events import EventBase
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
@ -65,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -860,8 +861,10 @@ class FederationHandlerRegistry:
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
# Map from type to instance name that we should route EDU handling to.
self._edu_type_to_instance = {} # type: Dict[str, str]
# Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new
# EDU received.
self._edu_type_to_instance = {} # type: Dict[str, List[str]]
def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
@ -905,7 +908,12 @@ class FederationHandlerRegistry:
def register_instance_for_edu(self, edu_type: str, instance_name: str):
"""Register that the EDU handler is on a different instance than master.
"""
self._edu_type_to_instance[edu_type] = instance_name
self._edu_type_to_instance[edu_type] = [instance_name]
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
"""Register that the EDU handler is on multiple instances.
"""
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict):
if not self.config.use_presence and edu_type == "m.presence":
@ -924,8 +932,11 @@ class FederationHandlerRegistry:
return
# Check if we can route it somewhere else that isn't us
route_to = self._edu_type_to_instance.get(edu_type, "master")
if route_to != self._instance_name:
instances = self._edu_type_to_instance.get(edu_type, ["master"])
if self._instance_name not in instances:
# Pick an instance randomly so that we don't overload one.
route_to = random.choice(instances)
try:
await self._send_edu(
instance_name=route_to,

View file

@ -28,7 +28,6 @@ from synapse.api.urls import (
FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX,
)
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
parse_boolean_from_args,
@ -45,6 +44,7 @@ from synapse.logging.opentracing import (
)
from synapse.server import HomeServer
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
from synapse.util.stringutils import parse_and_validate_server_name
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,14 +13,157 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import TYPE_CHECKING, List, Tuple
from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
ReplicationRemoveTagRestServlet,
ReplicationRoomAccountDataRestServlet,
ReplicationUserAccountDataRestServlet,
)
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class AccountDataHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
self._instance_name = hs.get_instance_name()
self._notifier = hs.get_notifier()
self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
self._account_data_writers = hs.config.worker.writers.account_data
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
"""Add some account_data to a room for a user.
Args:
user_id: The user to add a tag for.
room_id: The room to add a tag for.
account_data_type: The type of account_data to add.
content: A json object to associate with the tag.
Returns:
The maximum stream ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.add_account_data_to_room(
user_id, room_id, account_data_type, content
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._room_data_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
room_id=room_id,
account_data_type=account_data_type,
content=content,
)
return response["max_stream_id"]
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
"""Add some account_data to a room for a user.
Args:
user_id: The user to add a tag for.
account_data_type: The type of account_data to add.
content: A json object to associate with the tag.
Returns:
The maximum stream ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.add_account_data_for_user(
user_id, account_data_type, content
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._user_data_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
account_data_type=account_data_type,
content=content,
)
return response["max_stream_id"]
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
"""Add a tag to a room for a user.
Args:
user_id: The user to add a tag for.
room_id: The room to add a tag for.
tag: The tag name to add.
content: A json object to associate with the tag.
Returns:
The next account data ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.add_tag_to_room(
user_id, room_id, tag, content
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._add_tag_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
room_id=room_id,
tag=tag,
content=content,
)
return response["max_stream_id"]
async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
"""Remove a tag from a room for a user.
Returns:
The next account data ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.remove_tag_from_room(
user_id, room_id, tag
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._remove_tag_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
room_id=room_id,
tag=tag,
)
return response["max_stream_id"]
class AccountDataEventSource:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()

View file

@ -49,8 +49,13 @@ from synapse.api.errors import (
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
from synapse.handlers._base import BaseHandler
from synapse.handlers.ui_auth import (
INTERACTIVE_AUTH_CHECKERS,
UIAuthSessionDataConstants,
)
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http import get_request_user_agent
from synapse.http.server import finish_request, respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
@ -62,8 +67,6 @@ from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@ -260,10 +263,6 @@ class AuthHandler(BaseHandler):
# authenticating for an operation to occur on their account.
self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
self._sso_auth_success_template = hs.config.sso_auth_success_template
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
@ -284,7 +283,6 @@ class AuthHandler(BaseHandler):
requester: Requester,
request: SynapseRequest,
request_body: Dict[str, Any],
clientip: str,
description: str,
) -> Tuple[dict, Optional[str]]:
"""
@ -301,8 +299,6 @@ class AuthHandler(BaseHandler):
request_body: The body of the request sent by the client
clientip: The IP address of the client.
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
@ -338,10 +334,10 @@ class AuthHandler(BaseHandler):
request_body.pop("auth", None)
return request_body, None
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
# build a list of supported flows
supported_ui_auth_types = await self._get_available_ui_auth_types(
@ -349,13 +345,16 @@ class AuthHandler(BaseHandler):
)
flows = [[login_type] for login_type in supported_ui_auth_types]
def get_new_session_data() -> JsonDict:
return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
try:
result, params, session_id = await self.check_ui_auth(
flows, request, request_body, clientip, description
flows, request, request_body, description, get_new_session_data,
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
raise
# find the completed login type
@ -363,14 +362,14 @@ class AuthHandler(BaseHandler):
if login_type not in result:
continue
user_id = result[login_type]
validated_user_id = result[login_type]
break
else:
# this can't happen
raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token
if user_id != requester.user.to_string():
if validated_user_id != requester_user_id:
raise AuthError(403, "Invalid auth")
# Note that the access token has been validated.
@ -402,13 +401,9 @@ class AuthHandler(BaseHandler):
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
if await self.store.get_external_ids_by_user(user.to_string()):
ui_auth_types.add(LoginType.SSO)
# Our CAS impl does not (yet) correctly register users in user_external_ids,
# so always offer that if it's available.
if self.hs.config.cas.cas_enabled:
if await self.hs.get_sso_handler().get_identity_providers_for_user(
user.to_string()
):
ui_auth_types.add(LoginType.SSO)
return ui_auth_types
@ -426,8 +421,8 @@ class AuthHandler(BaseHandler):
flows: List[List[str]],
request: SynapseRequest,
clientdict: Dict[str, Any],
clientip: str,
description: str,
get_new_session_data: Optional[Callable[[], JsonDict]] = None,
) -> Tuple[dict, dict, str]:
"""
Takes a dictionary sent by the client in the login / registration
@ -448,11 +443,16 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip: The IP address of the client.
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
get_new_session_data:
an optional callback which will be called when starting a new session.
it should return data to be stored as part of the session.
The keys of the returned data should be entries in
UIAuthSessionDataConstants.
Returns:
A tuple of (creds, params, session_id).
@ -480,10 +480,15 @@ class AuthHandler(BaseHandler):
# If there's no session ID, create a new session.
if not sid:
new_session_data = get_new_session_data() if get_new_session_data else {}
session = await self.store.create_ui_auth_session(
clientdict, uri, method, description
)
for k, v in new_session_data.items():
await self.set_session_data(session.session_id, k, v)
else:
try:
session = await self.store.get_ui_auth_session(sid)
@ -539,7 +544,8 @@ class AuthHandler(BaseHandler):
# authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict)
user_agent = request.get_user_agent("")
user_agent = get_request_user_agent(request)
clientip = request.getClientIP()
await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip
@ -644,7 +650,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
key: The key to store the data under
key: The key to store the data under. An entry from
UIAuthSessionDataConstants.
value: The data to store
"""
try:
@ -660,7 +667,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
key: The key to store the data under
key: The key the data was stored under. An entry from
UIAuthSessionDataConstants.
default: Value to return if the key has not been set
"""
try:
@ -1334,12 +1342,12 @@ class AuthHandler(BaseHandler):
else:
return False
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
Args:
redirect_url: The URL to redirect to the SSO provider.
request: The incoming HTTP request
session_id: The user interactive authentication session ID.
Returns:
@ -1349,31 +1357,39 @@ class AuthHandler(BaseHandler):
session = await self.store.get_ui_auth_session(session_id)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
user_id_to_verify = await self.get_session_data(
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
) # type: str
idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
user_id_to_verify
)
if not idps:
# we checked that the user had some remote identities before offering an SSO
# flow, so either it's been deleted or the client has requested SSO despite
# it not being offered.
raise SynapseError(400, "User has no SSO identities")
# for now, just pick one
idp_id, sso_auth_provider = next(iter(idps.items()))
if len(idps) > 0:
logger.warning(
"User %r has previously logged in with multiple SSO IdPs; arbitrarily "
"picking %r",
user_id_to_verify,
idp_id,
)
redirect_url = await sso_auth_provider.handle_redirect_request(
request, None, session_id
)
return self._sso_auth_confirm_template.render(
description=session.description, redirect_url=redirect_url,
)
async def complete_sso_ui_auth(
self, registered_user_id: str, session_id: str, request: Request,
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id: The registered user ID to complete SSO login for.
session_id: The ID of the user-interactive auth session.
request: The request to complete.
"""
# Mark the stage of the authentication as successful.
# Save the user who authenticated with SSO, this will be used to ensure
# that the account be modified is also the person who logged in.
await self.store.mark_ui_auth_stage_complete(
session_id, LoginType.SSO, registered_user_id
)
# Render the HTML and return.
html = self._sso_auth_success_template
respond_with_html(request, 200, html)
async def complete_sso_login(
self,
registered_user_id: str,
@ -1488,8 +1504,8 @@ class AuthHandler(BaseHandler):
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({param_name: param})
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
query.append((param_name, param))
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)

View file

@ -75,10 +75,19 @@ class CasHandler:
self._http_client = hs.get_proxied_http_client()
# identifier for the external_ids table
self._auth_provider_id = "cas"
self.idp_id = "cas"
# user-facing name of this auth provider
self.idp_name = "CAS"
# we do not currently support icons for CAS auth, but this is required by
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
def _build_service_param(self, args: Dict[str, str]) -> str:
"""
Generates a value to use as the "service" parameter when redirecting or
@ -105,7 +114,7 @@ class CasHandler:
Args:
ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL.
Should be the same as those passed to `get_redirect_url`.
Should be the same as those passed to `handle_redirect_request`.
Raises:
CasError: If there's an error parsing the CAS response.
@ -184,16 +193,31 @@ class CasHandler:
return CasResponse(user, attributes)
def get_redirect_url(self, service_args: Dict[str, str]) -> str:
"""
Generates a URL for the CAS server where the client should be redirected.
async def handle_redirect_request(
self,
request: SynapseRequest,
client_redirect_url: Optional[bytes],
ui_auth_session_id: Optional[str] = None,
) -> str:
"""Generates a URL for the CAS server where the client should be redirected.
Args:
service_args: Additional arguments to include in the final redirect URL.
request: the incoming HTTP request
client_redirect_url: the URL that we should redirect the
client to after login (or None for UI Auth).
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
Returns:
The URL to redirect the client to.
URL to redirect to
"""
if ui_auth_session_id:
service_args = {"session": ui_auth_session_id}
else:
assert client_redirect_url
service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
args = urllib.parse.urlencode(
{"service": self._build_service_param(service_args)}
)
@ -275,7 +299,7 @@ class CasHandler:
# first check if we're doing a UIA
if session:
return await self._sso_handler.complete_sso_ui_auth_request(
self._auth_provider_id, cas_response.username, session, request,
self.idp_id, cas_response.username, session, request,
)
# otherwise, we're handling a login request.
@ -375,7 +399,7 @@ class CasHandler:
return None
await self._sso_handler.complete_sso_login_request(
self._auth_provider_id,
self.idp_id,
cas_response.username,
request,
client_redirect_url,

View file

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, create_requester
from synapse.types import Requester, UserID, create_requester
from ._base import BaseHandler
@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler()
self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self._server_name = hs.hostname
@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler):
self._account_validity_enabled = hs.config.account_validity.enabled
async def deactivate_account(
self, user_id: str, erase_data: bool, id_server: Optional[str] = None
self,
user_id: str,
erase_data: bool,
requester: Requester,
id_server: Optional[str] = None,
by_admin: bool = False,
) -> bool:
"""Deactivate a user's account
Args:
user_id: ID of user to be deactivated
erase_data: whether to GDPR-erase the user's data
requester: The user attempting to make this change.
id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
by_admin: Whether this change was made by an administrator.
Returns:
True if identity server supports removing threepids, otherwise False.
@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler):
# Mark the user as erased, if they asked for that
if erase_data:
user = UserID.from_string(user_id)
# Remove avatar URL from this user
await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
# Remove displayname from this user
await self._profile_handler.set_displayname(user, requester, "", by_admin)
logger.info("Marking %s as erased", user_id)
await self.store.mark_user_erased(user_id)

View file

@ -24,6 +24,7 @@ from synapse.logging.opentracing import (
set_tag,
start_active_span,
)
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
@ -44,13 +45,37 @@ class DeviceMessageHandler:
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender()
# We only need to poke the federation sender explicitly if its on the
# same instance. Other federation sender instances will get notified by
# `synapse.app.generic_worker.FederationSenderHandler` when it sees it
# in the to-device replication stream.
self.federation_sender = None
if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender()
# If we can handle the to device EDUs we do so, otherwise we route them
# to the appropriate worker.
if hs.get_instance_name() in hs.config.worker.writers.to_device:
hs.get_federation_registry().register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu
)
else:
hs.get_federation_registry().register_instances_for_edu(
"m.direct_to_device", hs.config.worker.writers.to_device,
)
self._device_list_updater = hs.get_device_handler().device_list_updater
# The handler to call when we think a user's device list might be out of
# sync. We do all device list resyncing on the master instance, so if
# we're on a worker we hit the device resync replication API.
if hs.config.worker.worker_app is None:
self._user_device_resync = (
hs.get_device_handler().device_list_updater.user_device_resync
)
else:
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
hs
)
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
local_messages = {}
@ -138,9 +163,7 @@ class DeviceMessageHandler:
await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
# Immediately attempt a resync in the background
run_in_background(
self._device_list_updater.user_device_resync, sender_user_id
)
run_in_background(self._user_device_resync, user_id=sender_user_id)
async def send_device_message(
self,
@ -195,7 +218,8 @@ class DeviceMessageHandler:
)
log_kv({"remote_messages": remote_messages})
if self.federation_sender:
for destination in remote_messages.keys():
# Enqueue a new federation transaction to send the new
# device messages to each remote destination.
self.federation.send_device_messages(destination)
self.federation_sender.send_device_messages(destination)

View file

@ -476,8 +476,6 @@ class IdentityHandler(BaseHandler):
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
assert self.hs.config.public_baseurl
# we need to tell the client to send the token back to us, since it doesn't
# otherwise know where to send it, so add submit_url response parameter
# (see also MSC2078)

View file

@ -14,7 +14,7 @@
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
from urllib.parse import urlencode
import attr
@ -35,7 +35,7 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
from synapse.config.oidc_config import OidcProviderConfig
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
@ -71,6 +71,144 @@ JWK = Dict[str, str]
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
self._sso_handler = hs.get_sso_handler()
provider_confs = hs.config.oidc.oidc_providers
# we should not have been instantiated if there is no configured provider.
assert provider_confs
self._token_generator = OidcSessionTokenGenerator(hs)
self._providers = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
} # type: Dict[str, OidcProvider]
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
Called at startup to ensure we have everything we need.
"""
for idp_id, p in self._providers.items():
try:
await p.load_metadata()
await p.load_jwks()
except Exception as e:
raise Exception(
"Error while initialising OIDC provider %r" % (idp_id,)
) from e
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
- first, we check if there was any error returned by the provider and
display it
- then we fetch the session cookie, decode and verify it
- the ``state`` query parameter should match with the one stored in the
session cookie
Once we know the session is legit, we then delegate to the OIDC Provider
implementation, which will exchange the code with the provider and complete the
login/authentication.
Args:
request: the incoming request from the browser.
"""
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
# error response from the auth server. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
error = request.args[b"error"][0].decode()
description = request.args.get(b"error_description", [b""])[0].decode()
# Most of the errors returned by the provider could be due by
# either the provider misbehaving or Synapse being misconfigured.
# The only exception of that is "access_denied", where the user
# probably cancelled the login flow. In other cases, log those errors.
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2
# Fetch the session cookie
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return
# Remove the cookie. There is a good chance that if the callback failed
# once, it will fail next time and the code will already be exchanged.
# Removing it early avoids spamming the provider with token requests.
request.addCookie(
SESSION_COOKIE_NAME,
b"",
path="/_synapse/oidc",
expires="Thu, Jan 01 1970 00:00:00 UTC",
httpOnly=True,
sameSite="lax",
)
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return
state = request.args[b"state"][0].decode()
# Deserialize the session token and verify it.
try:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
except (MacaroonDeserializationException, ValueError) as e:
logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
oidc_provider = self._providers.get(session_data.idp_id)
if not oidc_provider:
logger.error("OIDC session uses unknown IdP %r", oidc_provider)
self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
return
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return
code = request.args[b"code"][0].decode()
await oidc_provider.handle_oidc_callback(request, session_data, code)
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint
"""
@ -85,44 +223,61 @@ class OidcError(Exception):
return self.error
class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
class OidcProvider:
"""Wraps the config for a single OIDC IdentityProvider
Provides methods for handling redirect requests and callbacks via that particular
IdP.
"""
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
def __init__(
self,
hs: "HomeServer",
token_generator: "OidcSessionTokenGenerator",
provider: OidcProviderConfig,
):
self._store = hs.get_datastore()
self._token_generator = token_generator
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
self._client_auth = ClientAuth(
hs.config.oidc_client_id,
hs.config.oidc_client_secret,
hs.config.oidc_client_auth_method,
provider.client_id, provider.client_secret, provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = hs.config.oidc_client_auth_method # type: str
self._client_auth_method = provider.client_auth_method
self._provider_metadata = OpenIDProviderMetadata(
issuer=hs.config.oidc_issuer,
authorization_endpoint=hs.config.oidc_authorization_endpoint,
token_endpoint=hs.config.oidc_token_endpoint,
userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
jwks_uri=hs.config.oidc_jwks_uri,
issuer=provider.issuer,
authorization_endpoint=provider.authorization_endpoint,
token_endpoint=provider.token_endpoint,
userinfo_endpoint=provider.userinfo_endpoint,
jwks_uri=provider.jwks_uri,
) # type: OpenIDProviderMetadata
self._provider_needs_discovery = hs.config.oidc_discover # type: bool
self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
hs.config.oidc_user_mapping_provider_config
) # type: OidcMappingProvider
self._skip_verification = hs.config.oidc_skip_verification # type: bool
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._provider_needs_discovery = provider.discover
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
)
self._skip_verification = provider.skip_verification
self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
# identifier for the external_ids table
self._auth_provider_id = "oidc"
self.idp_id = provider.idp_id
# user-facing name of this auth provider
self.idp_name = provider.idp_name
# MXC URI for icon for this auth provider
self.idp_icon = provider.idp_icon
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
def _validate_metadata(self):
"""Verifies the provider metadata.
@ -475,7 +630,7 @@ class OidcHandler(BaseHandler):
async def handle_redirect_request(
self,
request: SynapseRequest,
client_redirect_url: bytes,
client_redirect_url: Optional[bytes],
ui_auth_session_id: Optional[str] = None,
) -> str:
"""Handle an incoming request to /login/sso/redirect
@ -499,7 +654,7 @@ class OidcHandler(BaseHandler):
request: the incoming request from the browser.
We'll respond to it with a redirect and a cookie.
client_redirect_url: the URL that we should redirect the client to
when everything is done
when everything is done (or None for UI Auth)
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
@ -511,11 +666,17 @@ class OidcHandler(BaseHandler):
state = generate_token()
nonce = generate_token()
cookie = self._generate_oidc_session_token(
if not client_redirect_url:
client_redirect_url = b""
cookie = self._token_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
idp_id=self.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
),
)
request.addCookie(
SESSION_COOKIE_NAME,
@ -538,22 +699,16 @@ class OidcHandler(BaseHandler):
nonce=nonce,
)
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
async def handle_oidc_callback(
self, request: SynapseRequest, session_data: "OidcSessionData", code: str
) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
``self._sso_handler.render_error`` which displays an HTML page for the error.
By this time we have already validated the session on the synapse side, and
now need to do the provider-specific operations. This includes:
Most of the OpenID Connect logic happens here:
- first, we check if there was any error returned by the provider and
display it
- then we fetch the session cookie, decode and verify it
- the ``state`` query parameter should match with the one stored in the
session cookie
- once we known this session is legit, exchange the code with the
provider using the ``token_endpoint`` (see ``_exchange_code``)
- exchange the code with the provider using the ``token_endpoint`` (see
``_exchange_code``)
- once we have the token, use it to either extract the UserInfo from
the ``id_token`` (``_parse_id_token``), or use the ``access_token``
to fetch UserInfo from the ``userinfo_endpoint``
@ -563,88 +718,12 @@ class OidcHandler(BaseHandler):
Args:
request: the incoming request from the browser.
session_data: the session data, extracted from our cookie
code: The authorization code we got from the callback.
"""
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
# error response from the auth server. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
error = request.args[b"error"][0].decode()
description = request.args.get(b"error_description", [b""])[0].decode()
# Most of the errors returned by the provider could be due by
# either the provider misbehaving or Synapse being misconfigured.
# The only exception of that is "access_denied", where the user
# probably cancelled the login flow. In other cases, log those errors.
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2
# Fetch the session cookie
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return
# Remove the cookie. There is a good chance that if the callback failed
# once, it will fail next time and the code will already be exchanged.
# Removing it early avoids spamming the provider with token requests.
request.addCookie(
SESSION_COOKIE_NAME,
b"",
path="/_synapse/oidc",
expires="Thu, Jan 01 1970 00:00:00 UTC",
httpOnly=True,
sameSite="lax",
)
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return
state = request.args[b"state"][0].decode()
# Deserialize the session token and verify it.
try:
(
nonce,
client_redirect_url,
ui_auth_session_id,
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return
logger.debug("Exchanging code")
code = request.args[b"code"][0].decode()
try:
logger.debug("Exchanging code")
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
@ -666,14 +745,14 @@ class OidcHandler(BaseHandler):
else:
logger.debug("Extracting userinfo from id_token")
try:
userinfo = await self._parse_id_token(token, nonce=nonce)
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
# first check if we're doing a UIA
if ui_auth_session_id:
if session_data.ui_auth_session_id:
try:
remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
@ -682,7 +761,7 @@ class OidcHandler(BaseHandler):
return
return await self._sso_handler.complete_sso_ui_auth_request(
self._auth_provider_id, remote_user_id, ui_auth_session_id, request
self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
)
# otherwise, it's a login
@ -690,133 +769,12 @@ class OidcHandler(BaseHandler):
# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
userinfo, token, request, client_redirect_url
userinfo, token, request, session_data.client_redirect_url
)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
def _generate_oidc_session_token(
self,
state: str,
nonce: str,
client_redirect_url: str,
ui_auth_session_id: Optional[str],
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
When Synapse initiates an authorization flow, it creates a random state
and a random nonce. Those parameters are given to the provider and
should be verified when the client comes back from the provider.
It is also used to store the client_redirect_url, which is used to
complete the SSO login flow.
Args:
state: The ``state`` parameter passed to the OIDC provider.
nonce: The ``nonce`` parameter passed to the OIDC provider.
client_redirect_url: The URL the client gave when it initiated the
flow.
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
Returns:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (client_redirect_url,)
)
if ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def _verify_oidc_session_token(
self, session: bytes, state: str
) -> Tuple[str, str, Optional[str]]:
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
and extract the nonce and client_redirect_url caveats.
Args:
session: The session token to verify
state: The state the OIDC provider gave back
Returns:
The nonce, client_redirect_url, and ui_auth_session_id for this session
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the `nonce`, `client_redirect_url`, and maybe the
# `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None
return nonce, client_redirect_url, ui_auth_session_id
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
Exception: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self.clock.time_msec()
return now < expiry
async def _complete_oidc_login(
self,
userinfo: UserInfo,
@ -893,8 +851,8 @@ class OidcHandler(BaseHandler):
# and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0)
user_id = UserID(attributes.localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
user_id = UserID(attributes.localpart, self._server_name).to_string()
users = await self._store.get_users_by_id_case_insensitive(user_id)
if users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
@ -923,7 +881,7 @@ class OidcHandler(BaseHandler):
extra_attributes = await get_extra_attributes(userinfo, token)
await self._sso_handler.complete_sso_login_request(
self._auth_provider_id,
self.idp_id,
remote_user_id,
request,
client_redirect_url,
@ -946,6 +904,157 @@ class OidcHandler(BaseHandler):
return str(remote_user_id)
class OidcSessionTokenGenerator:
"""Methods for generating and checking OIDC Session cookies."""
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._server_name = hs.hostname
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
def generate_oidc_session_token(
self,
state: str,
session_data: "OidcSessionData",
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
When Synapse initiates an authorization flow, it creates a random state
and a random nonce. Those parameters are given to the provider and
should be verified when the client comes back from the provider.
It is also used to store the client_redirect_url, which is used to
complete the SSO login flow.
Args:
state: The ``state`` parameter passed to the OIDC provider.
session_data: data to include in the session token.
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
Returns:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
)
if session_data.ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
)
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def verify_oidc_session_token(
self, session: bytes, state: str
) -> "OidcSessionData":
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
and extract the nonce and client_redirect_url caveats.
Args:
session: The session token to verify
state: The state the OIDC provider gave back
Returns:
The data extracted from the session cookie
Raises:
ValueError if an expected caveat is missing from the macaroon.
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the session data from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None
return OidcSessionData(
nonce=nonce,
idp_id=idp_id,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
)
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
ValueError: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
return now < expiry
@attr.s(frozen=True, slots=True)
class OidcSessionData:
"""The attributes which are stored in a OIDC session cookie"""
# the Identity Provider being used
idp_id = attr.ib(type=str)
# The `nonce` parameter passed to the OIDC provider.
nonce = attr.ib(type=str)
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
# The session ID of the ongoing UI Auth (None if this is a login)
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
)

View file

@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
return result["displayname"]
return result.get("displayname")
async def set_displayname(
self,
@ -246,7 +246,7 @@ class ProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
return result["avatar_url"]
return result.get("avatar_url")
async def set_avatar_url(
self,
@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
avatar_url_to_set = new_avatar_url # type: Optional[str]
if new_avatar_url == "":
avatar_url_to_set = None
# Same like set_displayname
if by_admin:
requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
await self.store.set_profile_avatar_url(
target_user.localpart, avatar_url_to_set
)
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)

View file

@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
super().__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.account_data_handler = hs.get_account_data_handler()
self.read_marker_linearizer = Linearizer(name="read_marker")
self.notifier = hs.get_notifier()
async def received_client_read_marker(
self, room_id: str, user_id: str, event_id: str
@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
if should_update:
content = {"event_id": event_id}
max_id = await self.store.add_account_data_to_room(
await self.account_data_handler.add_account_data_to_room(
user_id, room_id, "m.fully_read", content
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])

View file

@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler):
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs
self.federation = hs.get_federation_sender()
# We only need to poke the federation sender explicitly if its on the
# same instance. Other federation sender instances will get notified by
# `synapse.app.generic_worker.FederationSenderHandler` when it sees it
# in the receipts stream.
self.federation_sender = None
if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender()
# If we can handle the receipt EDUs we do so, otherwise we route them
# to the appropriate worker.
if hs.get_instance_name() in hs.config.worker.writers.receipts:
hs.get_federation_registry().register_edu_handler(
"m.receipt", self._received_remote_receipt
)
else:
hs.get_federation_registry().register_instances_for_edu(
"m.receipt", hs.config.worker.writers.receipts,
)
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler):
if not is_new:
return
await self.federation.send_read_receipt(receipt)
if self.federation_sender:
await self.federation_sender.send_read_receipt(receipt)
class ReceiptEventSource:

View file

@ -38,7 +38,6 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
@ -55,6 +54,7 @@ from synapse.types import (
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_and_validate_server_name
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@ -365,7 +365,7 @@ class RoomCreationHandler(BaseHandler):
creation_content = {
"room_version": new_room_version.identifier,
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
}
} # type: JsonDict
# Check if old room was non-federatable

View file

@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.account_data_handler = hs.get_account_data_handler()
self.member_linearizer = Linearizer(name="member")
@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
direct_rooms[key].append(new_room_id)
# Save back to user's m.direct account data
await self.store.add_account_data_for_user(
await self.account_data_handler.add_account_data_for_user(
user_id, AccountDataTypes.DIRECT, direct_rooms
)
break
@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
await self.account_data_handler.add_tag_to_room(
user_id, new_room_id, tag, tag_content
)
async def update_membership(
self,

View file

@ -73,27 +73,45 @@ class SamlHandler(BaseHandler):
)
# identifier for the external_ids table
self._auth_provider_id = "saml"
self.idp_id = "saml"
# user-facing name of this auth provider
self.idp_name = "SAML"
# we do not currently support icons for SAML auth, but this is required by
# the SsoIdentityProvider protocol type.
self.idp_icon = None
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
) -> bytes:
async def handle_redirect_request(
self,
request: SynapseRequest,
client_redirect_url: Optional[bytes],
ui_auth_session_id: Optional[str] = None,
) -> str:
"""Handle an incoming request to /login/sso/redirect
Args:
request: the incoming HTTP request
client_redirect_url: the URL that we should redirect the
client to when everything is done
client to after login (or None for UI Auth).
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
Returns:
URL to redirect to
"""
if not client_redirect_url:
# Some SAML identity providers (e.g. Google) require a
# RelayState parameter on requests, so pass in a dummy redirect URL
# (which will never get used).
client_redirect_url = b"unused"
reqid, info = self._saml_client.prepare_for_authenticate(
entityid=self._saml_idp_entityid, relay_state=client_redirect_url
)
@ -210,7 +228,7 @@ class SamlHandler(BaseHandler):
return
return await self._sso_handler.complete_sso_ui_auth_request(
self._auth_provider_id,
self.idp_id,
remote_user_id,
current_session.ui_auth_session_id,
request,
@ -306,7 +324,7 @@ class SamlHandler(BaseHandler):
return None
await self._sso_handler.complete_sso_login_request(
self._auth_provider_id,
self.idp_id,
remote_user_id,
request,
client_redirect_url,

View file

@ -12,15 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
from urllib.parse import urlencode
import attr
from typing_extensions import NoReturn
from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request
from synapse.api.errors import RedirectException, SynapseError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
@ -40,6 +45,63 @@ class MappingException(Exception):
"""
class SsoIdentityProvider(Protocol):
"""Abstract base class to be implemented by SSO Identity Providers
An Identity Provider, or IdP, is an external HTTP service which authenticates a user
to say whether they should be allowed to log in, or perform a given action.
Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
and CAS.
The main entry point is `handle_redirect_request`, which should return a URI to
redirect the user's browser to the IdP's authentication page.
Each IdP should be registered with the SsoHandler via
`hs.get_sso_handler().register_identity_provider()`, so that requests to
`/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
"""
@property
@abc.abstractmethod
def idp_id(self) -> str:
"""A unique identifier for this SSO provider
Eg, "saml", "cas", "github"
"""
@property
@abc.abstractmethod
def idp_name(self) -> str:
"""User-facing name for this provider"""
@property
def idp_icon(self) -> Optional[str]:
"""Optional MXC URI for user-facing icon"""
return None
@abc.abstractmethod
async def handle_redirect_request(
self,
request: SynapseRequest,
client_redirect_url: Optional[bytes],
ui_auth_session_id: Optional[str] = None,
) -> str:
"""Handle an incoming request to /login/sso/redirect
Args:
request: the incoming HTTP request
client_redirect_url: the URL that we should redirect the
client to after login (or None for UI Auth).
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
Returns:
URL to redirect to
"""
raise NotImplementedError()
@attr.s
class UserAttributes:
# the localpart of the mxid that the mapper has assigned to the user.
@ -91,8 +153,13 @@ class SsoHandler:
self._store = hs.get_datastore()
self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
self._auth_handler = hs.get_auth_handler()
self._error_template = hs.config.sso_error_template
self._bad_user_template = hs.config.sso_auth_bad_user_template
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
self._sso_auth_success_template = hs.config.sso_auth_success_template
# a lock on the mappings
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
@ -100,6 +167,49 @@ class SsoHandler:
# a map from session id to session data
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
# map from idp_id to SsoIdentityProvider
self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
def register_identity_provider(self, p: SsoIdentityProvider):
p_id = p.idp_id
assert p_id not in self._identity_providers
self._identity_providers[p_id] = p
def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
"""Get the configured identity providers"""
return self._identity_providers
async def get_identity_providers_for_user(
self, user_id: str
) -> Mapping[str, SsoIdentityProvider]:
"""Get the SsoIdentityProviders which a user has used
Given a user id, get the identity providers that that user has used to log in
with in the past (and thus could use to re-identify themselves for UI Auth).
Args:
user_id: MXID of user to look up
Raises:
a map of idp_id to SsoIdentityProvider
"""
external_ids = await self._store.get_external_ids_by_user(user_id)
valid_idps = {}
for idp_id, _ in external_ids:
idp = self._identity_providers.get(idp_id)
if not idp:
logger.warning(
"User %r has an SSO mapping for IdP %r, but this is no longer "
"configured.",
user_id,
idp_id,
)
else:
valid_idps[idp_id] = idp
return valid_idps
def render_error(
self,
request: Request,
@ -124,6 +234,34 @@ class SsoHandler:
)
respond_with_html(request, code, html)
async def handle_redirect_request(
self, request: SynapseRequest, client_redirect_url: bytes,
) -> str:
"""Handle a request to /login/sso/redirect
Args:
request: incoming HTTP request
client_redirect_url: the URL that we should redirect the
client to after login.
Returns:
the URI to redirect to
"""
if not self._identity_providers:
raise SynapseError(
400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
)
# if we only have one auth provider, redirect to it directly
if len(self._identity_providers) == 1:
ap = next(iter(self._identity_providers.values()))
return await ap.handle_redirect_request(request, client_redirect_url)
# otherwise, redirect to the IDP picker
return "/_synapse/client/pick_idp?" + urlencode(
(("redirectUrl", client_redirect_url),)
)
async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str
) -> Optional[str]:
@ -268,7 +406,7 @@ class SsoHandler:
attributes,
auth_provider_id,
remote_user_id,
request.get_user_agent(""),
get_request_user_agent(request),
request.getClientIP(),
)
@ -451,19 +589,45 @@ class SsoHandler:
auth_provider_id, remote_user_id,
)
user_id_to_verify = await self._auth_handler.get_session_data(
ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
) # type: str
if not user_id:
logger.warning(
"Remote user %s/%s has not previously logged in here: UIA will fail",
auth_provider_id,
remote_user_id,
)
# Let the UIA flow handle this the same as if they presented creds for a
# different user.
user_id = ""
await self._auth_handler.complete_sso_ui_auth(
user_id, ui_auth_session_id, request
elif user_id != user_id_to_verify:
logger.warning(
"Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
auth_provider_id,
remote_user_id,
user_id,
)
else:
# success!
# Mark the stage of the authentication as successful.
await self._store.mark_ui_auth_stage_complete(
ui_auth_session_id, LoginType.SSO, user_id
)
# Render the HTML confirmation page and return.
html = self._sso_auth_success_template
respond_with_html(request, 200, html)
return
# the user_id didn't match: mark the stage of the authentication as unsuccessful
await self._store.mark_ui_auth_stage_complete(
ui_auth_session_id, LoginType.SSO, ""
)
# render an error page.
html = self._bad_user_template.render(
server_name=self._server_name, user_id_to_verify=user_id_to_verify,
)
respond_with_html(request, 200, html)
async def check_username_availability(
self, localpart: str, session_id: str,
@ -534,7 +698,7 @@ class SsoHandler:
attributes,
session.auth_provider_id,
session.remote_user_id,
request.get_user_agent(""),
get_request_user_agent(request),
request.getClientIP(),
)

View file

@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
"""
from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401
class UIAuthSessionDataConstants:
"""Constants for use with AuthHandler.set_session_data"""
# used during registration and password reset to store a hashed copy of the
# password, so that the client does not need to submit it each time.
PASSWORD_HASH = "password_hash"
# used during registration to store the mxid of the registered user
REGISTERED_USER_ID = "registered_user_id"
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
# for.
REQUEST_USER_ID = "request_user_id"

View file

@ -17,6 +17,7 @@ import re
from twisted.internet import task
from twisted.web.client import FileBodyProducer
from twisted.web.iweb import IRequest
from synapse.api.errors import SynapseError
@ -50,3 +51,17 @@ class QuieterFileBodyProducer(FileBodyProducer):
FileBodyProducer.stopProducing(self)
except task.TaskStopped:
pass
def get_request_user_agent(request: IRequest, default: str = "") -> str:
"""Return the last User-Agent header, or the given default.
"""
# There could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically
# with maximum recursion trying to log errors about
# the charset problem.
# c.f. https://github.com/matrix-org/synapse/issues/3471
h = request.getHeader(b"User-Agent")
return h.decode("ascii", "replace") if h else default

View file

@ -32,7 +32,7 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
from netaddr import IPAddress, IPSet
from netaddr import AddrFormatError, IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
@ -261,16 +261,16 @@ class BlacklistingAgentWrapper(Agent):
try:
ip_address = IPAddress(h.hostname)
except AddrFormatError:
# Not an IP
pass
else:
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info("Blocking access to %s due to blacklist" % (ip_address,))
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
except Exception:
# Not an IP
pass
return self._agent.request(
method, uri, headers=headers, bodyProducer=bodyProducer
@ -725,7 +725,7 @@ class SimpleHttpClient:
read_body_with_max_size(response, output_stream, max_size)
)
except BodyExceededMaxSize:
SynapseError(
raise SynapseError(
502,
"Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
@ -767,14 +767,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.max_size = max_size
def dataReceived(self, data: bytes) -> None:
# If the deferred was called, bail early.
if self.deferred.called:
return
self.stream.write(data)
self.length += len(data)
# The first time the maximum size is exceeded, error and cancel the
# connection. dataReceived might be called again if data was received
# in the meantime.
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(BodyExceededMaxSize())
self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason: Failure) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):

View file

@ -1,79 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
logger = logging.getLogger(__name__)
def parse_server_name(server_name):
"""Split a server name into host/port parts.
Args:
server_name (str): server name to parse
Returns:
Tuple[str, int|None]: host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
try:
if server_name[-1] == "]":
# ipv6 literal, hopefully
return server_name, None
domain_port = server_name.rsplit(":", 1)
domain = domain_port[0]
port = int(domain_port[1]) if domain_port[1:] else None
return domain, port
except Exception:
raise ValueError("Invalid server name '%s'" % server_name)
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
def parse_and_validate_server_name(server_name):
"""Split a server name into host/port parts and do some basic validation.
Args:
server_name (str): server name to parse
Returns:
Tuple[str, int|None]: host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
host, port = parse_server_name(server_name)
# these tests don't need to be bulletproof as we'll find out soon enough
# if somebody is giving us invalid data. What we *do* need is to be sure
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
if host[0] == "[":
if host[-1] != "]":
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
return host, port
# otherwise it should only be alphanumerics.
if not VALID_HOST_REGEX.match(host):
raise ValueError(
"Server name '%s' contains invalid characters" % (server_name,)
)
return host, port

View file

@ -102,7 +102,6 @@ class MatrixFederationAgent:
pool=self._pool,
contextFactory=tls_client_options_factory,
),
self._reactor,
ip_blacklist=ip_blacklist,
),
user_agent=self.user_agent,

View file

@ -174,6 +174,16 @@ async def _handle_json_response(
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = await make_deferred_yieldable(d)
except ValueError as e:
# The JSON content was invalid.
logger.warning(
"{%s} [%s] Failed to parse JSON response - %s %s",
request.txn_id,
request.destination,
request.method,
request.uri.decode("ascii"),
)
raise RequestSendFailed(e, can_retry=False) from e
except defer.TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response - %s %s",
@ -986,7 +996,7 @@ class MatrixFederationHttpClient:
logger.warning(
"{%s} [%s] %s", request.txn_id, request.destination, msg,
)
SynapseError(502, msg, Codes.TOO_LARGE)
raise SynapseError(502, msg, Codes.TOO_LARGE)
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",

View file

@ -20,7 +20,7 @@ from twisted.python.failure import Failure
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
from synapse.http import redact_uri
from synapse.http import get_request_user_agent, redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.types import Requester
@ -113,15 +113,6 @@ class SynapseRequest(Request):
method = self.method.decode("ascii")
return method
def get_user_agent(self, default: str) -> str:
"""Return the last User-Agent header, or the given default.
"""
user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
if user_agent is None:
return default
return user_agent.decode("ascii", "replace")
def render(self, resrc):
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
@ -292,12 +283,7 @@ class SynapseRequest(Request):
# and can see that we're doing something wrong.
authenticated_entity = repr(self.requester) # type: ignore[unreachable]
# ...or could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically
# with maximum recursion trying to log errors about
# the charset problem.
# c.f. https://github.com/matrix-org/synapse/issues/3471
user_agent = self.get_user_agent("-")
user_agent = get_request_user_agent(self, "-")
code = str(self.code)
if not self.finished:

View file

@ -252,7 +252,12 @@ class LoggingContext:
"scope",
]
def __init__(self, name=None, parent_context=None, request=None) -> None:
def __init__(
self,
name: Optional[str] = None,
parent_context: "Optional[LoggingContext]" = None,
request: Optional[str] = None,
) -> None:
self.previous_context = current_context()
self.name = name
@ -536,20 +541,20 @@ class LoggingContextFilter(logging.Filter):
def __init__(self, request: str = ""):
self._default_request = request
def filter(self, record) -> Literal[True]:
def filter(self, record: logging.LogRecord) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
Returns:
True to include the record in the log output.
"""
context = current_context()
record.request = self._default_request
record.request = self._default_request # type: ignore
# context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
if context is not None:
# Logging is interested in the request.
record.request = context.request
record.request = context.request # type: ignore
return True
@ -616,9 +621,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
return current
def nested_logging_context(
suffix: str, parent_context: Optional[LoggingContext] = None
) -> LoggingContext:
def nested_logging_context(suffix: str) -> LoggingContext:
"""Creates a new logging context as a child of another.
The nested logging context will have a 'request' made up of the parent context's
@ -632,20 +635,23 @@ def nested_logging_context(
# ... do stuff
Args:
suffix (str): suffix to add to the parent context's 'request'.
parent_context (LoggingContext|None): parent context. Will use the current context
if None.
suffix: suffix to add to the parent context's 'request'.
Returns:
LoggingContext: new logging context.
"""
if parent_context is not None:
context = parent_context # type: LoggingContextOrSentinel
else:
context = current_context()
return LoggingContext(
parent_context=context, request=str(context.request) + "-" + suffix
curr_context = current_context()
if not curr_context:
logger.warning(
"Starting nested logging context from sentinel context: metrics will be lost"
)
parent_context = None
prefix = ""
else:
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
prefix = str(parent_context.request)
return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
def preserve_fn(f):
@ -822,10 +828,18 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
Deferred: A Deferred which fires a callback with the result of `f`, or an
errback if `f` throws an exception.
"""
logcontext = current_context()
curr_context = current_context()
if not curr_context:
logger.warning(
"Calling defer_to_threadpool from sentinel context: metrics will be lost"
)
parent_context = None
else:
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
def g():
with LoggingContext(parent_context=logcontext):
with LoggingContext(parent_context=parent_context):
return f(*args, **kwargs)
return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))

View file

@ -396,7 +396,6 @@ class Notifier:
Will wake up all listeners for the given users and rooms.
"""
with PreserveLoggingContext():
with Measure(self.clock, "on_new_event"):
user_streams = set()

View file

@ -203,13 +203,17 @@ class BulkPushRuleEvaluator:
condition_cache = {} # type: Dict[str, bool]
# If the event is not a state event check if any users ignore the sender.
if not event.is_state():
ignorers = await self.store.ignored_by(event.sender)
else:
ignorers = set()
for uid, rules in rules_by_user.items():
if event.sender == uid:
continue
if not event.is_state():
is_ignored = await self.store.is_ignored_by(event.sender, uid)
if is_ignored:
if uid in ignorers:
continue
display_name = None

View file

@ -15,6 +15,7 @@
from synapse.http.server import JsonResource
from synapse.replication.http import (
account_data,
devices,
federation,
login,
@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
presence.register_servlets(hs, self)
membership.register_servlets(hs, self)
streams.register_servlets(hs, self)
account_data.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:

View file

@ -177,7 +177,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
@trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress()
async def send_request(instance_name="master", **kwargs):
async def send_request(*, instance_name="master", **kwargs):
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
if instance_name == "master":

View file

@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
"""Add user account data on the appropriate account data worker.
Request format:
POST /_synapse/replication/add_user_account_data/:user_id/:type
{
"content": { ... },
}
"""
NAME = "add_user_account_data"
PATH_ARGS = ("user_id", "account_data_type")
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, account_data_type, content):
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, account_data_type):
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_for_user(
user_id, account_data_type, content["content"]
)
return 200, {"max_stream_id": max_stream_id}
class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
"""Add room account data on the appropriate account data worker.
Request format:
POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
{
"content": { ... },
}
"""
NAME = "add_room_account_data"
PATH_ARGS = ("user_id", "room_id", "account_data_type")
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, account_data_type, content):
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, room_id, account_data_type):
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, content["content"]
)
return 200, {"max_stream_id": max_stream_id}
class ReplicationAddTagRestServlet(ReplicationEndpoint):
"""Add tag on the appropriate account data worker.
Request format:
POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
{
"content": { ... },
}
"""
NAME = "add_tag"
PATH_ARGS = ("user_id", "room_id", "tag")
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, tag, content):
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, room_id, tag):
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_tag_to_room(
user_id, room_id, tag, content["content"]
)
return 200, {"max_stream_id": max_stream_id}
class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
"""Remove tag on the appropriate account data worker.
Request format:
POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
{}
"""
NAME = "remove_tag"
PATH_ARGS = (
"user_id",
"room_id",
"tag",
)
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, tag):
return {}
async def _handle_request(self, request, user_id, room_id, tag):
max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
return 200, {"max_stream_id": max_stream_id}
def register_servlets(hs, http_server):
ReplicationUserAccountDataRestServlet(hs).register(http_server)
ReplicationRoomAccountDataRestServlet(hs).register(http_server)
ReplicationAddTagRestServlet(hs).register(http_server)
ReplicationRemoveTagRestServlet(hs).register(http_server)

View file

@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
database,
stream_name="caches",
instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
tables=[
(
"cache_invalidation_stream_by_instance",
"instance_name",
"stream_id",
)
],
sequence_name="cache_invalidation_stream_seq",
writers=[],
) # type: Optional[MultiWriterIdGenerator]

View file

@ -15,47 +15,9 @@
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
db_conn,
"account_data",
"stream_id",
extra_tables=[
("room_account_data", "stream_id"),
("room_tags_revisions", "stream_id"),
],
)
super().__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
(row.data_type, row.user_id)
)
self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
pass

View file

@ -14,46 +14,8 @@
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_inbox", "stream_id"
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
self._device_inbox_id_gen.get_current_token(),
)
self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache",
self._device_inbox_id_gen.get_current_token(),
)
self._last_device_delete_cache = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
self._device_inbox_stream_cache.entity_has_changed(
row.entity, token
)
else:
self._device_federation_outbox_stream_cache.entity_has_changed(
row.entity, token
)
return super().process_replication_rows(stream_name, instance_name, token, rows)
pass

View file

@ -14,43 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
super().__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
pass

View file

@ -51,11 +51,15 @@ from synapse.replication.tcp.commands import (
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
AccountDataStream,
BackfillStream,
CachesStream,
EventsStream,
FederationStream,
ReceiptsStream,
Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
@ -115,6 +119,14 @@ class ReplicationCommandHandler:
continue
if isinstance(stream, ToDeviceStream):
# Only add ToDeviceStream as a source on instances in charge of
# sending to device messages.
if hs.get_instance_name() in hs.config.worker.writers.to_device:
self._streams_to_replicate.append(stream)
continue
if isinstance(stream, TypingStream):
# Only add TypingStream as a source on the instance in charge of
# typing.
@ -123,6 +135,22 @@ class ReplicationCommandHandler:
continue
if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
# Only add AccountDataStream and TagAccountDataStream as a source on the
# instance in charge of account_data persistence.
if hs.get_instance_name() in hs.config.worker.writers.account_data:
self._streams_to_replicate.append(stream)
continue
if isinstance(stream, ReceiptsStream):
# Only add ReceiptsStream as a source on the instance in charge of
# receipts.
if hs.get_instance_name() in hs.config.worker.writers.receipts:
self._streams_to_replicate.append(stream)
continue
# Only add any other streams if we're on master.
if hs.config.worker_app is not None:
continue

View file

@ -0,0 +1,18 @@
<html>
<head>
<title>Authentication Failed</title>
</head>
<body>
<div>
<p>
We were unable to validate your <tt>{{server_name | e}}</tt> account via
single-sign-on (SSO), because the SSO Identity Provider returned
different details than when you logged in.
</p>
<p>
Try the operation again, and ensure that you use the same details on
the Identity Provider as when you log into your account.
</p>
</div>
</body>
</html>

View file

@ -0,0 +1,31 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<link rel="stylesheet" href="/_matrix/static/client/login/style.css">
<title>{{server_name | e}} Login</title>
</head>
<body>
<div id="container">
<h1 id="title">{{server_name | e}} Login</h1>
<div class="login_flow">
<p>Choose one of the following identity providers:</p>
<form>
<input type="hidden" name="redirectUrl" value="{{redirect_url | e}}">
<ul class="radiobuttons">
{% for p in providers %}
<li>
<input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
<label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
{% if p.idp_icon %}
<img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
{% endif %}
</li>
{% endfor %}
</ul>
<input type="submit" class="button button--full-width" id="button-submit" value="Submit">
</form>
</div>
</div>
</body>
</html>

View file

@ -15,6 +15,9 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.http import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
assert_requester_is_admin,
assert_user_is_admin,
)
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, room_id: str):
async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, user_id: str):
async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, server_name: str, media_id: str):
async def on_POST(
self, request: Request, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
return 200, {}
class ProtectMediaByID(RestServlet):
"""Protect local media from being quarantined.
"""
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
logging.info("Protecting local media by ID: %s", media_id)
# Quarantine this media id
await self.store.mark_local_media_as_safe(media_id)
return 200, {}
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id):
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_media_cache")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
async def on_POST(self, request):
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
async def on_DELETE(self, request, server_name: str, media_id: str):
async def on_DELETE(
self, request: Request, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if self.server_name != server_name:
@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
async def on_POST(self, request, server_name: str):
async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
return 200, {"deleted_media": deleted_media, "total": total}
def register_servlets_for_media_repo(hs, http_server):
def register_servlets_for_media_repo(hs: "HomeServer", http_server):
"""
Media repo specific APIs.
"""
@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
QuarantineMediaInRoom(hs).register(http_server)
QuarantineMediaByID(hs).register(http_server)
QuarantineMediaByUser(hs).register(http_server)
ProtectMediaByID(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
DeleteMediaByID(hs).register(http_server)
DeleteMediaByDateSize(hs).register(http_server)

View file

@ -244,7 +244,7 @@ class UserRestServletV2(RestServlet):
if deactivate and not user["deactivated"]:
await self.deactivate_account_handler.deactivate_account(
target_user.to_string(), False
target_user.to_string(), False, requester, by_admin=True
)
elif not deactivate and user["deactivated"]:
if "password" not in body:
@ -486,12 +486,22 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
self.auth = hs.get_auth()
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
if not self.is_mine(UserID.from_string(target_user_id)):
raise SynapseError(400, "Can only deactivate local users")
if not await self.store.get_user_by_id(target_user_id):
raise NotFoundError("User not found")
async def on_POST(self, request, target_user_id):
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@ -501,10 +511,8 @@ class DeactivateAccountRestServlet(RestServlet):
Codes.BAD_JSON,
)
UserID.from_string(target_user_id)
result = await self._deactivate_account_handler.deactivate_account(
target_user_id, erase
target_user_id, erase, requester, by_admin=True
)
if result:
id_server_unbind_result = "success"
@ -714,13 +722,6 @@ class UserMembershipRestServlet(RestServlet):
async def on_GET(self, request, user_id):
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
raise SynapseError(400, "Can only lookup local users")
user = await self.store.get_user_by_id(user_id)
if user is None:
raise NotFoundError("Unknown user")
room_ids = await self.store.get_rooms_for_user(user_id)
ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
return 200, ret

View file

@ -311,48 +311,31 @@ class LoginRestServlet(RestServlet):
return result
class BaseSSORedirectServlet(RestServlet):
"""Common base class for /login/sso/redirect impls"""
class SsoRedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
# register themselves with the main SSOHandler.
if hs.config.cas_enabled:
hs.get_cas_handler()
if hs.config.saml2_enabled:
hs.get_saml_handler()
if hs.config.oidc_enabled:
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()
async def on_GET(self, request: SynapseRequest):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
sso_url = await self.get_sso_url(request, client_redirect_url)
client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None
)
sso_url = await self._sso_handler.handle_redirect_request(
request, client_redirect_url
)
logger.info("Redirecting to %s", sso_url)
request.redirect(sso_url)
finish_request(request)
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
request: The client request to redirect.
client_redirect_url: the URL that we should redirect the
client to when everything is done
Returns:
URL to redirect to
"""
# to be implemented by subclasses
raise NotImplementedError()
class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._cas_handler = hs.get_cas_handler()
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
return self._cas_handler.get_redirect_url(
{"redirectUrl": client_redirect_url}
).encode("ascii")
class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@ -379,40 +362,8 @@ class CasTicketServlet(RestServlet):
)
class SAMLRedirectServlet(BaseSSORedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
class OIDCRedirectServlet(BaseSSORedirectServlet):
"""Implementation for /login/sso/redirect for the OIDC login flow."""
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
def __init__(self, hs):
self._oidc_handler = hs.get_oidc_handler()
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
return await self._oidc_handler.handle_redirect_request(
request, client_redirect_url
)
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
elif hs.config.saml2_enabled:
SAMLRedirectServlet(hs).register(http_server)
elif hs.config.oidc_enabled:
OIDCRedirectServlet(hs).register(http_server)

View file

@ -46,7 +46,7 @@ from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from synapse.util import json_decoder
from synapse.util.stringutils import random_string
from synapse.util.stringutils import parse_and_validate_server_name, random_string
if TYPE_CHECKING:
import synapse.server
@ -350,8 +350,6 @@ class PublicRoomListRestServlet(TransactionRestServlet):
# provided.
if server:
raise e
else:
pass
limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None)
@ -362,6 +360,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server_name:
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)
except ValueError:
raise SynapseError(
400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
)
try:
data = await handler.get_remote_public_room_list(
server, limit=limit, since_token=since_token
@ -405,6 +411,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server_name:
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)
except ValueError:
raise SynapseError(
400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
)
try:
data = await handler.get_remote_public_room_list(
server,

View file

@ -20,9 +20,6 @@ from http import HTTPStatus
from typing import TYPE_CHECKING
from urllib.parse import urlparse
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
@ -31,6 +28,7 @@ from synapse.api.errors import (
ThreepidValidationError,
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@ -46,6 +44,10 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -189,11 +191,7 @@ class PasswordRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
try:
params, session_id = await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"modify your account password",
requester, request, body, "modify your account password",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth, but
@ -204,7 +202,9 @@ class PasswordRestServlet(RestServlet):
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
)
raise
user_id = requester.user.to_string()
@ -215,7 +215,6 @@ class PasswordRestServlet(RestServlet):
[[LoginType.EMAIL_IDENTITY]],
request,
body,
self.hs.get_ip_from_request(request),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
@ -227,7 +226,9 @@ class PasswordRestServlet(RestServlet):
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
)
raise
@ -260,7 +261,7 @@ class PasswordRestServlet(RestServlet):
password_hash = await self.auth_handler.hash(new_password)
elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
else:
# UI validation was skipped, but the request did not include a new
@ -304,19 +305,18 @@ class DeactivateAccountRestServlet(RestServlet):
# allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase
requester.user.to_string(), erase, requester
)
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"deactivate your account",
requester, request, body, "deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, id_server=body.get("id_server")
requester.user.to_string(),
erase,
requester,
id_server=body.get("id_server"),
)
if result:
id_server_unbind_result = "success"
@ -695,11 +695,7 @@ class ThreepidAddRestServlet(RestServlet):
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"add a third-party identifier to your account",
requester, request, body, "add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(

View file

@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, account_data_type):
if self._is_worker:
raise Exception("Cannot handle PUT /account_data on worker")
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
body = parse_json_object_from_request(request)
max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
await self.handler.add_account_data_for_user(user_id, account_data_type, body)
return 200, {}
@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, account_data_type):
if self._is_worker:
raise Exception("Cannot handle PUT /account_data on worker")
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers",
)
max_id = await self.store.add_account_data_to_room(
await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {}
async def on_GET(self, request, user_id, room_id, account_data_type):

View file

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet):
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
# SSO configuration.
self._cas_enabled = hs.config.cas_enabled
if self._cas_enabled:
self._cas_handler = hs.get_cas_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self._saml_enabled = hs.config.saml2_enabled
if self._saml_enabled:
self._saml_handler = hs.get_saml_handler()
self._oidc_enabled = hs.config.oidc_enabled
if self._oidc_enabled:
self._oidc_handler = hs.get_oidc_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
self.success_template = hs.config.fallback_success_template
@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet):
elif stagetype == LoginType.SSO:
# Display a confirmation page which prompts the user to
# re-authenticate with their SSO provider.
if self._cas_enabled:
# Generate a request to CAS that redirects back to an endpoint
# to verify the successful authentication.
sso_redirect_url = self._cas_handler.get_redirect_url(
{"session": session},
)
elif self._saml_enabled:
# Some SAML identity providers (e.g. Google) require a
# RelayState parameter on requests. It is not necessary here, so
# pass in a dummy redirect URL (which will never get used).
client_redirect_url = b"unused"
sso_redirect_url = self._saml_handler.handle_redirect_request(
client_redirect_url, session
)
elif self._oidc_enabled:
client_redirect_url = b""
sso_redirect_url = await self._oidc_handler.handle_redirect_request(
request, client_redirect_url, session
)
else:
raise SynapseError(400, "Homeserver not configured for SSO.")
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
html = await self.auth_handler.start_sso_ui_auth(request, session)
else:
raise SynapseError(404, "Unknown auth stage type")
@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
success = await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
LoginType.RECAPTCHA, authdict, request.getClientIP()
)
if success:
@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet):
authdict = {"session": session}
success = await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
LoginType.TERMS, authdict, request.getClientIP()
)
if success:

View file

@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"remove device(s) from your account",
requester, request, body, "remove device(s) from your account",
)
await self.device_handler.delete_devices(
@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
raise
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"remove a device from your account",
requester, request, body, "remove a device from your account",
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)

View file

@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"add a device signing key to your account",
requester, request, body, "add a device signing key to your account",
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)

View file

@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@ -353,7 +354,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
ip = self.hs.get_ip_from_request(request)
ip = request.getClientIP()
with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred
@ -494,11 +495,11 @@ class RegisterRestServlet(RestServlet):
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
)
# Extract the previously-hashed password from the session.
password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
# Ensure that the username is valid.
@ -513,11 +514,7 @@ class RegisterRestServlet(RestServlet):
# not this will raise a user-interactive auth error.
try:
auth_result, params, session_id = await self.auth_handler.check_ui_auth(
self._registration_flows,
request,
body,
self.hs.get_ip_from_request(request),
"register a new account",
self._registration_flows, request, body, "register a new account",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth.
@ -532,7 +529,9 @@ class RegisterRestServlet(RestServlet):
if not password_hash and password:
password_hash = await self.auth_handler.hash(password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
)
raise
@ -633,7 +632,9 @@ class RegisterRestServlet(RestServlet):
# Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
session_id,
UIAuthSessionDataConstants.REGISTERED_USER_ID,
registered_user_id,
)
registered = True

View file

@ -58,8 +58,7 @@ class TagServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, tag):
requester = await self.auth.get_user_by_req(request)
@ -68,9 +67,7 @@ class TagServlet(RestServlet):
body = parse_json_object_from_request(request)
max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
await self.handler.add_tag_to_room(user_id, room_id, tag, body)
return 200, {}
@ -79,9 +76,7 @@ class TagServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
await self.handler.remove_tag_from_room(user_id, room_id, tag)
return 200, {}

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,10 +17,11 @@
import logging
import os
import urllib
from typing import Awaitable
from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
]
def parse_media_id(request):
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
@ -69,7 +70,7 @@ def parse_media_id(request):
)
def respond_404(request):
def respond_404(request: Request) -> None:
respond_with_json(
request,
404,
@ -79,8 +80,12 @@ def respond_404(request):
async def respond_with_file(
request, media_type, file_path, file_size=None, upload_name=None
):
request: Request,
media_type: str,
file_path: str,
file_size: Optional[int] = None,
upload_name: Optional[str] = None,
) -> None:
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@ -98,15 +103,20 @@ async def respond_with_file(
respond_404(request)
def add_file_headers(request, media_type, file_size, upload_name):
def add_file_headers(
request: Request,
media_type: str,
file_size: Optional[int],
upload_name: Optional[str],
) -> None:
"""Adds the correct response headers in preparation for responding with the
media.
Args:
request (twisted.web.http.Request)
media_type (str): The media/content type.
file_size (int): Size in bytes of the media, if known.
upload_name (str): The name of the requested file, if any.
request
media_type: The media/content type.
file_size: Size in bytes of the media, if known.
upload_name: The name of the requested file, if any.
"""
def _quote(x):
@ -153,6 +163,7 @@ def add_file_headers(request, media_type, file_size, upload_name):
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
if file_size is not None:
request.setHeader(b"Content-Length", b"%d" % (file_size,))
# Tell web crawlers to not index, archive, or follow links in media. This
@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
}
def _can_encode_filename_as_token(x):
def _can_encode_filename_as_token(x: str) -> bool:
for c in x:
# from RFC2616:
#
@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x):
async def respond_with_responder(
request, responder, media_type, file_size, upload_name=None
):
request: Request,
responder: "Optional[Responder]",
media_type: str,
file_size: Optional[int],
upload_name: Optional[str] = None,
) -> None:
"""Responds to the request with given responder. If responder is None then
returns 404.
Args:
request (twisted.web.http.Request)
responder (Responder|None)
media_type (str): The media/content type.
file_size (int|None): Size in bytes of the media. If not known it should be None
upload_name (str|None): The name of the requested file, if any.
request
responder
media_type: The media/content type.
file_size: Size in bytes of the media. If not known it should be None
upload_name: The name of the requested file, if any.
"""
if request._disconnected:
logger.warning(
@ -308,22 +323,22 @@ class FileInfo:
self.thumbnail_type = thumbnail_type
def get_filename_from_headers(headers):
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
"""
Get the filename of the downloaded file by inspecting the
Content-Disposition HTTP header.
Args:
headers (dict[bytes, list[bytes]]): The HTTP request headers.
headers: The HTTP request headers.
Returns:
A Unicode string of the filename, or None.
The filename, or None.
"""
content_disposition = headers.get(b"Content-Disposition", [b""])
# No header, bail out.
if not content_disposition[0]:
return
return None
_, params = _parse_header(content_disposition[0])
@ -356,17 +371,16 @@ def get_filename_from_headers(headers):
return upload_name
def _parse_header(line):
def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
"""Parse a Content-type like header.
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
line (bytes): header to be parsed
line: header to be parsed
Returns:
Tuple[bytes, dict[bytes, bytes]]:
the main content-type, followed by the parameter dictionary
The main content-type, followed by the parameter dictionary
"""
parts = _parseparam(b";" + line)
key = next(parts)
@ -386,16 +400,16 @@ def _parse_header(line):
return key, pdict
def _parseparam(s):
def _parseparam(s: bytes) -> Generator[bytes, None, None]:
"""Generator which splits the input on ;, respecting double-quoted sequences
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
s (bytes): header to be parsed
s: header to be parsed
Returns:
Iterable[bytes]: the split input
The split input
"""
while s[:1] == b";":
s = s[1:]

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 Will Hunt <will@half-shot.uk>
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,22 +15,29 @@
# limitations under the License.
#
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class MediaConfigResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
async def _async_render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,24 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
import synapse.http.servlet
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean
from ._base import parse_media_id, respond_404
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__)
class DownloadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self.server_name = hs.hostname
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",
@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)
else:
allow_remote = synapse.http.servlet.parse_boolean(
request, "allow_remote", default=True
)
allow_remote = parse_boolean(request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,11 +17,12 @@
import functools
import os
import re
from typing import Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
def _wrap_in_base_path(func):
def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
"""Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store
"""
@ -41,12 +43,18 @@ class MediaFilePaths:
to write to the backup media store (when one is configured)
"""
def __init__(self, primary_base_path):
def __init__(self, primary_base_path: str):
self.base_path = primary_base_path
def default_thumbnail_rel(
self, default_top_level, default_sub_type, width, height, content_type, method
):
self,
default_top_level: str,
default_sub_type: str,
width: int,
height: int,
content_type: str,
method: str,
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@ -55,12 +63,14 @@ class MediaFilePaths:
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
def local_media_filepath_rel(self, media_id):
def local_media_filepath_rel(self, media_id: str) -> str:
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
def local_media_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@ -86,7 +96,7 @@ class MediaFilePaths:
media_id[4:],
)
def remote_media_filepath_rel(self, server_name, file_id):
def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
)
@ -94,8 +104,14 @@ class MediaFilePaths:
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
def remote_media_thumbnail_rel(
self, server_name, file_id, width, height, content_type, method
):
self,
server_name: str,
file_id: str,
width: int,
height: int,
content_type: str,
method: str,
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@ -113,7 +129,7 @@ class MediaFilePaths:
# Should be removed after some time, when most of the thumbnails are stored
# using the new path.
def remote_media_thumbnail_rel_legacy(
self, server_name, file_id, width, height, content_type
self, server_name: str, file_id: str, width: int, height: int, content_type: str
):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@ -126,7 +142,7 @@ class MediaFilePaths:
file_name,
)
def remote_media_thumbnail_dir(self, server_name, file_id):
def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join(
self.base_path,
"remote_thumbnail",
@ -136,7 +152,7 @@ class MediaFilePaths:
file_id[4:],
)
def url_cache_filepath_rel(self, media_id):
def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -146,7 +162,7 @@ class MediaFilePaths:
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
def url_cache_filepath_dirs_to_delete(self, media_id):
def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@ -156,7 +172,9 @@ class MediaFilePaths:
os.path.join(self.base_path, "url_cache", media_id[0:2]),
]
def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
def url_cache_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -178,7 +196,7 @@ class MediaFilePaths:
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
def url_cache_thumbnail_directory(self, media_id):
def url_cache_thumbnail_directory(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -195,7 +213,7 @@ class MediaFilePaths:
media_id[4:],
)
def url_cache_thumbnail_dirs_to_delete(self, media_id):
def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import errno
import logging
import os
import shutil
from typing import IO, Dict, List, Optional, Tuple
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.client = hs.get_federation_http_client()
@ -73,16 +76,16 @@ class MediaRepository:
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.primary_base_path = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path)
self.primary_base_path = hs.config.media_store_path # type: str
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
self.recently_accessed_locals = set()
self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
self.recently_accessed_locals = set() # type: Set[str]
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@ -113,7 +116,7 @@ class MediaRepository:
"update_recently_accessed_media", self._update_recently_accessed
)
async def _update_recently_accessed(self):
async def _update_recently_accessed(self) -> None:
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
@ -124,12 +127,12 @@ class MediaRepository:
local_media, remote_media, self.clock.time_msec()
)
def mark_recently_accessed(self, server_name, media_id):
def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
"""Mark the given media as recently accessed.
Args:
server_name (str|None): Origin server of media, or None if local
media_id (str): The media ID of the content
server_name: Origin server of media, or None if local
media_id: The media ID of the content
"""
if server_name:
self.recently_accessed_remotes.add((server_name, media_id))
@ -459,7 +462,14 @@ class MediaRepository:
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
def _generate_thumbnail(
self,
thumbnailer: Thumbnailer,
t_width: int,
t_height: int,
t_method: str,
t_type: str,
) -> Optional[BytesIO]:
m_width = thumbnailer.width
m_height = thumbnailer.height
@ -470,22 +480,20 @@ class MediaRepository:
m_height,
self.max_image_pixels,
)
return
return None
if thumbnailer.transpose_method is not None:
m_width, m_height = thumbnailer.transpose()
if t_method == "crop":
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
return thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
else:
t_byte_source = None
return thumbnailer.scale(t_width, t_height, t_type)
return t_byte_source
return None
async def generate_local_exact_thumbnail(
self,
@ -776,7 +784,7 @@ class MediaRepository:
return {"width": m_width, "height": m_height}
async def delete_old_remote_media(self, before_ts):
async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
within a given rectangle.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
# If we're not configured to use it, raise if we somehow got here.
if not hs.config.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.")

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vecotr Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,6 +18,8 @@ import os
import shutil
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@ -270,7 +272,7 @@ class MediaStorage:
return self.filepaths.local_media_filepath_rel(file_info.file_id)
def _write_file_synchronously(source, dest):
def _write_file_synchronously(source: IO, dest: IO) -> None:
"""Write `source` to the file like `dest` synchronously. Should be called
from a thread.
@ -286,14 +288,14 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.
Args:
open_file (file): A file like object to be streamed ot the client,
open_file: A file like object to be streamed ot the client,
is closed when finished streaming.
"""
def __init__(self, open_file):
def __init__(self, open_file: IO):
self.open_file = open_file
def write_to_consumer(self, consumer):
def write_to_consumer(self, consumer: IConsumer) -> Deferred:
return make_deferred_yieldable(
FileSender().beginFileTransfer(self.open_file, consumer)
)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import errno
import fnmatch
@ -23,12 +23,13 @@ import re
import shutil
import sys
import traceback
from typing import Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
from urllib import parse as urlparse
import attr
from twisted.internet.error import DNSLookupError
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string
from ._base import FileInfo
if TYPE_CHECKING:
from lxml import etree
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__)
_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@ -107,7 +115,12 @@ class OEmbedError(Exception):
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()
self.auth = hs.get_auth()
@ -155,11 +168,11 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000
)
async def _async_render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request: Request) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
@ -439,7 +452,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
async def _download_url(self, url: str, user):
async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@ -569,7 +582,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"expire_url_cache_data", self._expire_url_cache_data
)
async def _expire_url_cache_data(self):
async def _expire_url_cache_data(self) -> None:
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
@ -665,7 +678,9 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
# If there's no body, nothing useful is going to be found.
if not body:
return {}
@ -686,7 +701,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]
return og
def _calc_og(tree, media_uri):
def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
@ -790,7 +805,9 @@ def _calc_og(tree, media_uri):
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og["og:description"] = summarize_paragraphs(text_nodes)
else:
elif og["og:description"]:
# This must be a non-empty string at this point.
assert isinstance(og["og:description"], str)
og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling,
@ -798,7 +815,9 @@ def _calc_og(tree, media_uri):
return og
def _iterate_over_text(tree, *tags_to_ignore):
def _iterate_over_text(
tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
@ -832,32 +851,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
)
def _rebase_url(url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith("/"):
url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
return urlparse.urlunparse(url)
def _rebase_url(url: str, base: str) -> str:
base_parts = list(urlparse.urlparse(base))
url_parts = list(urlparse.urlparse(url))
if not url_parts[0]: # fix up schema
url_parts[0] = base_parts[0] or "http"
if not url_parts[1]: # fix up hostname
url_parts[1] = base_parts[1]
if not url_parts[2].startswith("/"):
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
return urlparse.urlunparse(url_parts)
def _is_media(content_type):
if content_type.lower().startswith("image/"):
return True
def _is_media(content_type: str) -> bool:
return content_type.lower().startswith("image/")
def _is_html(content_type):
def _is_html(content_type: str) -> bool:
content_type = content_type.lower()
if content_type.startswith("text/html") or content_type.startswith(
return content_type.startswith("text/html") or content_type.startswith(
"application/xhtml"
):
return True
)
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
# Try to get a summary of between 200 and 500 words, respecting
# first paragraph and then word boundaries.
# TODO: Respect sentences?

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import os
import shutil
from typing import Optional
from typing import TYPE_CHECKING, Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
@ -27,13 +28,17 @@ from .media_storage import FileResponder
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class StorageProvider:
class StorageProvider(metaclass=abc.ABCMeta):
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
async def store_file(self, path: str, file_info: FileInfo):
@abc.abstractmethod
async def store_file(self, path: str, file_info: FileInfo) -> None:
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
@ -42,6 +47,7 @@ class StorageProvider:
file_info: The metadata of the file.
"""
@abc.abstractmethod
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
self.store_synchronous = store_synchronous
self.store_remote = store_remote
def __str__(self):
def __str__(self) -> str:
return "StorageProviderWrapper[%s]" % (self.backend,)
async def store_file(self, path, file_info):
async def store_file(self, path: str, file_info: FileInfo) -> None:
if not file_info.server_name and not self.store_local:
return None
@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.store_file(path, file_info))
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
async def store():
@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file")
run_in_background(store)
return None
async def fetch(self, path, file_info):
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.fetch(path, file_info))
@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem.
Args:
hs (HomeServer)
hs
config: The config returned by `parse_config`.
"""
def __init__(self, hs, config):
def __init__(self, hs: "HomeServer", config: str):
self.hs = hs
self.cache_directory = hs.config.media_store_path
self.base_directory = config
@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path, file_info):
async def store_file(self, path: str, file_info: FileInfo) -> None:
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
return await defer_to_thread(
await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
async def fetch(self, path, file_info):
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb"))
return None
@staticmethod
def parse_config(config):
def parse_config(config: dict) -> str:
"""Called on startup to parse config supplied. This should parse
the config and raise if there is a problem.

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,10 +16,14 @@
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
FileInfo,
@ -28,13 +33,22 @@ from ._base import (
respond_with_responder,
)
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__)
class ThumbnailResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()
self.store = hs.get_datastore()
@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_repo.mark_recently_accessed(server_name, media_id)
async def _respond_local_thumbnail(
self, request, media_id, width, height, method, m_type
):
self,
request: Request,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_local_thumbnail(
self,
request,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
):
request: Request,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail(
self,
request,
server_name,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
):
request: Request,
server_name: str,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
) -> None:
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = await self.store.get_remote_media_thumbnails(
@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource):
raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
):
self,
request: Request,
server_name: str,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource):
def _select_thumbnail(
self,
desired_width,
desired_height,
desired_method,
desired_type,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos,
):
) -> dict:
d_w = desired_width
d_h = desired_height

Some files were not shown because too many files have changed in this diff Show more