Mirror of https://git.anonymousland.org/anonymousland/synapse.git (synced 2025-08-17 13:00:17 -04:00)

Commit d3547df958: Merge remote-tracking branch 'upstream/release-v1.26.0'

183 changed files with 8528 additions and 2668 deletions
@@ -15,6 +15,7 @@

 # limitations under the License.

 import logging

 from synapse.storage.engines import create_engine

 logger = logging.getLogger("create_postgres_db")
2 .gitignore (vendored)

@@ -12,10 +12,12 @@
 _trial_temp/
 _trial_temp*/
 /out
+.DS_Store

 # stuff that is likely to exist when you run a server locally
 /*.db
 /*.log
+/*.log.*
 /*.log.config
 /*.pid
 /.python-version
87 CHANGES.md

@@ -1,3 +1,90 @@
+Synapse 1.26.0rc1 (2021-01-20)
+==============================
+
+This release brings a new schema version for Synapse and rolling back to a previous
+version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details
+on these changes and for general upgrade guidance.
+
+Features
+--------
+
+- Add support for multiple SSO Identity Providers. ([\#9015](https://github.com/matrix-org/synapse/issues/9015), [\#9017](https://github.com/matrix-org/synapse/issues/9017), [\#9036](https://github.com/matrix-org/synapse/issues/9036), [\#9067](https://github.com/matrix-org/synapse/issues/9067), [\#9081](https://github.com/matrix-org/synapse/issues/9081), [\#9082](https://github.com/matrix-org/synapse/issues/9082), [\#9105](https://github.com/matrix-org/synapse/issues/9105), [\#9107](https://github.com/matrix-org/synapse/issues/9107), [\#9109](https://github.com/matrix-org/synapse/issues/9109), [\#9110](https://github.com/matrix-org/synapse/issues/9110), [\#9127](https://github.com/matrix-org/synapse/issues/9127), [\#9153](https://github.com/matrix-org/synapse/issues/9153), [\#9154](https://github.com/matrix-org/synapse/issues/9154), [\#9177](https://github.com/matrix-org/synapse/issues/9177))
+- During user-interactive authentication via single-sign-on, give a better error if the user uses the wrong account on the SSO IdP. ([\#9091](https://github.com/matrix-org/synapse/issues/9091))
+- Give the `public_baseurl` a default value, if it is not explicitly set in the configuration file. ([\#9159](https://github.com/matrix-org/synapse/issues/9159))
+- Improve performance when calculating ignored users in large rooms. ([\#9024](https://github.com/matrix-org/synapse/issues/9024))
+- Implement [MSC2176](https://github.com/matrix-org/matrix-doc/pull/2176) in an experimental room version. ([\#8984](https://github.com/matrix-org/synapse/issues/8984))
+- Add an admin API for protecting local media from quarantine. ([\#9086](https://github.com/matrix-org/synapse/issues/9086))
+- Remove a user's avatar URL and display name when deactivated with the Admin API. ([\#8932](https://github.com/matrix-org/synapse/issues/8932))
+- Update `/_synapse/admin/v1/users/<user_id>/joined_rooms` to work for both local and remote users. ([\#8948](https://github.com/matrix-org/synapse/issues/8948))
+- Add experimental support for handling to-device messages on worker processes. ([\#9042](https://github.com/matrix-org/synapse/issues/9042), [\#9043](https://github.com/matrix-org/synapse/issues/9043), [\#9044](https://github.com/matrix-org/synapse/issues/9044), [\#9130](https://github.com/matrix-org/synapse/issues/9130))
+- Add experimental support for handling `/keys/claim` and `/room_keys` APIs on worker processes. ([\#9068](https://github.com/matrix-org/synapse/issues/9068))
+- Add experimental support for handling `/devices` API on worker processes. ([\#9092](https://github.com/matrix-org/synapse/issues/9092))
+- Add experimental support for moving receipts and account data persistence off master. ([\#9104](https://github.com/matrix-org/synapse/issues/9104), [\#9166](https://github.com/matrix-org/synapse/issues/9166))
+
+
+Bugfixes
+--------
+
+- Fix a long-standing issue where an internal server error would occur when requesting a profile over federation that did not include a display name / avatar URL. ([\#9023](https://github.com/matrix-org/synapse/issues/9023))
+- Fix a long-standing bug where some caches could grow larger than configured. ([\#9028](https://github.com/matrix-org/synapse/issues/9028))
+- Fix error handling during insertion of client IPs into the database. ([\#9051](https://github.com/matrix-org/synapse/issues/9051))
+- Fix a bug where we didn't correctly record CPU time spent in the `on_new_event` block. ([\#9053](https://github.com/matrix-org/synapse/issues/9053))
+- Fix a minor bug which could cause confusing error messages from invalid configurations. ([\#9054](https://github.com/matrix-org/synapse/issues/9054))
+- Fix incorrect exit code when there is an error at startup. ([\#9059](https://github.com/matrix-org/synapse/issues/9059))
+- Fix `JSONDecodeError` spamming the logs when sending transactions to remote servers. ([\#9070](https://github.com/matrix-org/synapse/issues/9070))
+- Fix "Failed to send request" errors when a client provides an invalid room alias. ([\#9071](https://github.com/matrix-org/synapse/issues/9071))
+- Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.8.0 and v1.21.0. ([\#9114](https://github.com/matrix-org/synapse/issues/9114), [\#9116](https://github.com/matrix-org/synapse/issues/9116))
+- Fix corruption of `pushers` data when a postgres bouncer is used. ([\#9117](https://github.com/matrix-org/synapse/issues/9117))
+- Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. ([\#9128](https://github.com/matrix-org/synapse/issues/9128))
+- Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when fetching `.well-known` files that are too large. ([\#9108](https://github.com/matrix-org/synapse/issues/9108))
+- Fix "UnboundLocalError: local variable 'length' referenced before assignment" errors when the response body exceeds the expected size. This bug was introduced in v1.25.0. ([\#9145](https://github.com/matrix-org/synapse/issues/9145))
+- Fix a long-standing bug "ValueError: invalid literal for int() with base 10" when `/publicRooms` is requested with an invalid `server` parameter. ([\#9161](https://github.com/matrix-org/synapse/issues/9161))
+
+
+Improved Documentation
+----------------------
+
+- Add some extra docs for getting Synapse running on macOS. ([\#8997](https://github.com/matrix-org/synapse/issues/8997))
+- Correct a typo in the `systemd-with-workers` documentation. ([\#9035](https://github.com/matrix-org/synapse/issues/9035))
+- Correct a typo in `INSTALL.md`. ([\#9040](https://github.com/matrix-org/synapse/issues/9040))
+- Add missing `user_mapping_provider` configuration to the Keycloak OIDC example. Contributed by @chris-ruecker. ([\#9057](https://github.com/matrix-org/synapse/issues/9057))
+- Quote `pip install` packages when extras are used to avoid shells interpreting bracket characters. ([\#9151](https://github.com/matrix-org/synapse/issues/9151))
+
+
+Deprecations and Removals
+-------------------------
+
+- Remove broken and unmaintained `demo/webserver.py` script. ([\#9039](https://github.com/matrix-org/synapse/issues/9039))
+
+
+Internal Changes
+----------------
+
+- Improve efficiency of large state resolutions. ([\#8868](https://github.com/matrix-org/synapse/issues/8868), [\#9029](https://github.com/matrix-org/synapse/issues/9029), [\#9115](https://github.com/matrix-org/synapse/issues/9115), [\#9118](https://github.com/matrix-org/synapse/issues/9118), [\#9124](https://github.com/matrix-org/synapse/issues/9124))
+- Various clean-ups to the structured logging and logging context code. ([\#8939](https://github.com/matrix-org/synapse/issues/8939))
+- Ensure rejected events get added to some metadata tables. ([\#9016](https://github.com/matrix-org/synapse/issues/9016))
+- Ignore date-rotated homeserver logs saved to disk. ([\#9018](https://github.com/matrix-org/synapse/issues/9018))
+- Remove an unused column from `access_tokens` table. ([\#9025](https://github.com/matrix-org/synapse/issues/9025))
+- Add a `-noextras` factor to `tox.ini`, to support running the tests with no optional dependencies. ([\#9030](https://github.com/matrix-org/synapse/issues/9030))
+- Fix running unit tests when optional dependencies are not installed. ([\#9031](https://github.com/matrix-org/synapse/issues/9031))
+- Allow bumping schema version when using split out state database. ([\#9033](https://github.com/matrix-org/synapse/issues/9033))
+- Configure the linters to run on a consistent set of files. ([\#9038](https://github.com/matrix-org/synapse/issues/9038))
+- Various cleanups to device inbox store. ([\#9041](https://github.com/matrix-org/synapse/issues/9041))
+- Drop unused database tables. ([\#9055](https://github.com/matrix-org/synapse/issues/9055))
+- Remove unused `SynapseService` class. ([\#9058](https://github.com/matrix-org/synapse/issues/9058))
+- Remove unnecessary declarations in the tests for the admin API. ([\#9063](https://github.com/matrix-org/synapse/issues/9063))
+- Remove `SynapseRequest.get_user_agent`. ([\#9069](https://github.com/matrix-org/synapse/issues/9069))
+- Remove redundant `Homeserver.get_ip_from_request` method. ([\#9080](https://github.com/matrix-org/synapse/issues/9080))
+- Add type hints to media repository. ([\#9093](https://github.com/matrix-org/synapse/issues/9093))
+- Fix the wrong arguments being passed to `BlacklistingAgentWrapper` from `MatrixFederationAgent`. Contributed by Timothy Leung. ([\#9098](https://github.com/matrix-org/synapse/issues/9098))
+- Reduce the scope of caught exceptions in `BlacklistingAgentWrapper`. ([\#9106](https://github.com/matrix-org/synapse/issues/9106))
+- Improve `UsernamePickerTestCase`. ([\#9112](https://github.com/matrix-org/synapse/issues/9112))
+- Remove dependency on `distutils`. ([\#9125](https://github.com/matrix-org/synapse/issues/9125))
+- Enforce that replication HTTP clients are called with keyword arguments only. ([\#9144](https://github.com/matrix-org/synapse/issues/9144))
+- Fix the Python 3.5 / old dependencies build in CI. ([\#9146](https://github.com/matrix-org/synapse/issues/9146))
+- Replace the old `perspectives` option in the Synapse docker config file template with `trusted_key_servers`. ([\#9157](https://github.com/matrix-org/synapse/issues/9157))
+
+
 Synapse 1.25.0 (2021-01-13)
 ===========================
INSTALL.md

@@ -190,7 +190,8 @@ via brew and inform `pip` about it so that `psycopg2` builds:

 ```sh
 brew install openssl@1.1
-export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/
+export LDFLAGS="-L/usr/local/opt/openssl/lib"
+export CPPFLAGS="-I/usr/local/opt/openssl/include"
 ```

 ##### OpenSUSE

@@ -257,7 +258,7 @@ for a number of platforms.

 #### Docker images and Ansible playbooks

-There is an offical synapse image available at
+There is an official synapse image available at
 <https://hub.docker.com/r/matrixdotorg/synapse> which can be used with
 the docker-compose file available at [contrib/docker](contrib/docker). Further
 information on this including configuration options is available in the README
23 README.rst

@@ -243,7 +243,7 @@ Then update the ``users`` table in the database::

 Synapse Development
 ===================

-Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org)
+Join our developer community on Matrix: `#synapse-dev:matrix.org <https://matrix.to/#/#synapse-dev:matrix.org>`_

 Before setting up a development environment for synapse, make sure you have the
 system dependencies (such as the python header files) installed - see

@@ -280,6 +280,27 @@ differ)::

   PASSED (skips=15, successes=1322)

+We recommend using the demo which starts 3 federated instances running on ports `8080` - `8082`
+
+    ./demo/start.sh
+
+(to stop, you can use `./demo/stop.sh`)
+
+If you just want to start a single instance of the app and run it directly:
+
+    # Create the homeserver.yaml config once
+    python -m synapse.app.homeserver \
+        --server-name my.domain.name \
+        --config-path homeserver.yaml \
+        --generate-config \
+        --report-stats=[yes|no]
+
+    # Start the app
+    python -m synapse.app.homeserver --config-path homeserver.yaml
+

 Running the Integration Tests
 =============================
50 UPGRADE.rst

@@ -85,6 +85,56 @@ for example:

    wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
    dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb

+Upgrading to v1.26.0
+====================
+
+Rolling back to v1.25.0 after a failed upgrade
+----------------------------------------------
+
+v1.26.0 includes a lot of large changes. If something problematic occurs, you
+may want to roll back to a previous version of Synapse. Because v1.26.0 also
+includes a new database schema version, reverting that version is also required
+alongside the generic rollback instructions mentioned above. In short, to roll
+back to v1.25.0 you need to:
+
+1. Stop the server
+
+2. Decrease the schema version in the database:
+
+   .. code:: sql
+
+      UPDATE schema_version SET version = 58;
+
+3. Delete the ignored users & chain cover data:
+
+   .. code:: sql
+
+      DROP TABLE IF EXISTS ignored_users;
+      UPDATE rooms SET has_auth_chain_index = false;
+
+   For PostgreSQL run:
+
+   .. code:: sql
+
+      TRUNCATE event_auth_chain_links;
+      TRUNCATE event_auth_chains;
+
+   For SQLite run:
+
+   .. code:: sql
+
+      DELETE FROM event_auth_chain_links;
+      DELETE FROM event_auth_chains;
+
+4. Mark the deltas as not run (so they will re-run on upgrade).
+
+   .. code:: sql
+
+      DELETE FROM applied_schema_deltas WHERE version = 59 AND file = '59/01ignored_user.py';
+      DELETE FROM applied_schema_deltas WHERE version = 59 AND file = '59/06chain_cover_index.sql';
+
+5. Downgrade Synapse by following the instructions for your installation method
+   in the "Rolling back to older versions" section above.
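On a SQLite deployment, steps 2-4 above could equally be applied with a short script along the following lines — a sketch only, assuming the database file is `homeserver.db` and that Synapse has already been stopped (both assumptions; adjust for your deployment):

```python
import sqlite3

# Assumed path to the SQLite homeserver database; adjust for your deployment.
DB_PATH = "homeserver.db"

ROLLBACK_STATEMENTS = [
    # Step 2: decrease the schema version.
    "UPDATE schema_version SET version = 58",
    # Step 3: delete the ignored users & chain cover data
    # (SQLite stores booleans as integers, so 0 means false).
    "DROP TABLE IF EXISTS ignored_users",
    "UPDATE rooms SET has_auth_chain_index = 0",
    # SQLite variant of the chain cover clean-up (use TRUNCATE on PostgreSQL).
    "DELETE FROM event_auth_chain_links",
    "DELETE FROM event_auth_chains",
    # Step 4: mark the deltas as not run.
    "DELETE FROM applied_schema_deltas"
    " WHERE version = 59 AND file = '59/01ignored_user.py'",
    "DELETE FROM applied_schema_deltas"
    " WHERE version = 59 AND file = '59/06chain_cover_index.sql'",
]

conn = sqlite3.connect(DB_PATH)
try:
    for stmt in ROLLBACK_STATEMENTS:
        conn.execute(stmt)
    conn.commit()  # apply the whole rollback in one transaction
finally:
    conn.close()
```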
 Upgrading to v1.25.0
 ====================
6 debian/changelog (vendored)

@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium
+
+  * Remove dependency on `python3-distutils`.
+
+ -- Richard van der Hoff <richard@matrix.org>  Fri, 15 Jan 2021 12:44:19 +0000
+
 matrix-synapse-py3 (1.25.0) stable; urgency=medium

   [ Dan Callahan ]
1 debian/control (vendored)

@@ -31,7 +31,6 @@ Pre-Depends: dpkg (>= 1.16.1)
 Depends:
  adduser,
  debconf,
- python3-distutils|libpython3-stdlib (<< 3.6),
 ${misc:Depends},
 ${shlibs:Depends},
 ${synapse:pydepends},
demo/webserver.py (deleted)

@@ -1,59 +0,0 @@
-import argparse
-import BaseHTTPServer
-import os
-import SimpleHTTPServer
-import cgi, logging
-
-from daemonize import Daemonize
-
-
-class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler):
-    UPLOAD_PATH = "upload"
-
-    """
-    Accept all post request as file upload
-    """
-
-    def do_POST(self):
-
-        path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path))
-        length = self.headers["content-length"]
-        data = self.rfile.read(int(length))
-
-        with open(path, "wb") as fh:
-            fh.write(data)
-
-        self.send_response(200)
-        self.send_header("Content-Type", "application/json")
-        self.end_headers()
-
-        # Return the absolute path of the uploaded file
-        self.wfile.write('{"url":"/%s"}' % path)
-
-
-def setup():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("directory")
-    parser.add_argument("-p", "--port", dest="port", type=int, default=8080)
-    parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid")
-    args = parser.parse_args()
-
-    # Get absolute path to directory to serve, as daemonize changes to '/'
-    os.chdir(args.directory)
-    dr = os.getcwd()
-
-    httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST)
-
-    def run():
-        os.chdir(dr)
-        httpd.serve_forever()
-
-    daemon = Daemonize(
-        app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False
-    )
-
-    daemon.start()
-
-
-if __name__ == "__main__":
-    setup()
docker/conf/homeserver.yaml

@@ -198,12 +198,10 @@ old_signing_keys: {}
 key_refresh_interval: "1d" # 1 Day.

 # The trusted servers to download signing keys from.
-perspectives:
-  servers:
-    "matrix.org":
-      verify_keys:
-        "ed25519:auto":
-          key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+trusted_key_servers:
+  - server_name: matrix.org
+    verify_keys:
+      "ed25519:auto": "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"

 password_config:
    enabled: true
docs/admin_api/media_admin_api.md

@@ -4,6 +4,7 @@
 * [Quarantining media by ID](#quarantining-media-by-id)
 * [Quarantining media in a room](#quarantining-media-in-a-room)
 * [Quarantining all media of a user](#quarantining-all-media-of-a-user)
+* [Protecting media from being quarantined](#protecting-media-from-being-quarantined)
 - [Delete local media](#delete-local-media)
 * [Delete a specific local media](#delete-a-specific-local-media)
 * [Delete local media by date or size](#delete-local-media-by-date-or-size)

@@ -123,6 +124,29 @@ The following fields are returned in the JSON response body:

 * `num_quarantined`: integer - The number of media items successfully quarantined

+## Protecting media from being quarantined
+
+This API protects a single piece of local media from being quarantined using the
+above APIs. This is useful for sticker packs and other shared media which you do
+not want to get quarantined, especially when
+[quarantining media in a room](#quarantining-media-in-a-room).
+
+Request:
+
+```
+POST /_synapse/admin/v1/media/protect/<media_id>
+
+{}
+```
+
+Where `media_id` is in the form of `abcdefg12345...`.
+
+Response:
+
+```json
+{}
+```
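For illustration, the request can be issued from any HTTP client; the sketch below uses Python's third-party `requests` library, with a placeholder homeserver URL, admin access token and media ID:

```python
import requests  # third-party HTTP client, assumed to be installed

SERVER = "https://matrix.example.com"  # placeholder homeserver URL
ADMIN_TOKEN = "syt_..."                # placeholder admin access token
MEDIA_ID = "abcdefg12345"              # the local media ID to protect

resp = requests.post(
    f"{SERVER}/_synapse/admin/v1/media/protect/{MEDIA_ID}",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    json={},  # the endpoint expects an empty JSON body
)
resp.raise_for_status()
print(resp.json())  # {} on success
```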
 # Delete local media
 This API deletes the *local* media from the disk of your own server.
 This includes any local thumbnails and copies of media downloaded from
docs/admin_api/user_admin_api.rst

@@ -98,6 +98,8 @@ Body parameters:

 - ``deactivated``, optional. If unspecified, deactivation state will be left
   unchanged on existing accounts and set to ``false`` for new accounts.
+  A user cannot be erased by deactivating with this API. For details on deactivating users see
+  `Deactivate Account <#deactivate-account>`_.

 If the user already exists then optional parameters default to the current value.

@@ -248,6 +250,25 @@ server admin: see `README.rst <README.rst>`_.

 The erase parameter is optional and defaults to ``false``.
 An empty body may be passed for backwards compatibility.

+The following actions are performed when deactivating a user:
+
+- Try to unbind 3PIDs from the identity server
+- Remove all 3PIDs from the homeserver
+- Delete all devices and E2EE keys
+- Delete all access tokens
+- Delete the password hash
+- Remove the user from all rooms they are a member of
+- Remove the user from the user directory
+- Reject all pending invites
+- Remove all account validity information related to the user
+
+The following additional actions are performed during deactivation if ``erase``
+is set to ``true``:
+
+- Remove the user's display name
+- Remove the user's avatar URL
+- Mark the user as erased
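For illustration, a deactivation request with ``erase`` enabled could be sent as in the following sketch, using the Deactivate Account endpoint referenced above (Python's third-party `requests` library; the homeserver URL, admin token and user ID are placeholders):

```python
import requests  # third-party HTTP client, assumed to be installed

SERVER = "https://matrix.example.com"    # placeholder homeserver URL
ADMIN_TOKEN = "syt_..."                  # placeholder admin access token
USER_ID = "@someuser:matrix.example.com" # placeholder user to deactivate

resp = requests.post(
    f"{SERVER}/_synapse/admin/v1/deactivate/{USER_ID}",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    # erase=True additionally removes the display name and avatar URL
    # and marks the user as erased, as described above.
    json={"erase": True},
)
resp.raise_for_status()
```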
 Reset password
 ==============

@@ -337,6 +358,10 @@ A response body like the following is returned:

     "total": 2
 }

+The server returns the list of rooms of which both the user and the server
+are members. If the user is local, all the rooms of which the user is
+a member are returned.

 **Parameters**

 The following parameters should be set in the URL:
32 docs/auth_chain_diff.dot (new file)

@@ -0,0 +1,32 @@
+digraph auth {
+    nodesep=0.5;
+    rankdir="RL";
+
+    C [label="Create (1,1)"];
+
+    BJ [label="Bob's Join (2,1)", color=red];
+    BJ2 [label="Bob's Join (2,2)", color=red];
+    BJ2 -> BJ [color=red, dir=none];
+
+    subgraph cluster_foo {
+        A1 [label="Alice's invite (4,1)", color=blue];
+        A2 [label="Alice's Join (4,2)", color=blue];
+        A3 [label="Alice's Join (4,3)", color=blue];
+        A3 -> A2 -> A1 [color=blue, dir=none];
+        color=none;
+    }
+
+    PL1 [label="Power Level (3,1)", color=darkgreen];
+    PL2 [label="Power Level (3,2)", color=darkgreen];
+    PL2 -> PL1 [color=darkgreen, dir=none];
+
+    {rank = same; C; BJ; PL1; A1;}
+
+    A1 -> C [color=grey];
+    A1 -> BJ [color=grey];
+    PL1 -> C [color=grey];
+    BJ2 -> PL1 [penwidth=2];
+
+    A3 -> PL2 [penwidth=2];
+    A1 -> PL1 -> BJ -> C [penwidth=2];
+}
BIN docs/auth_chain_diff.dot.png (new file, 41 KiB; binary not shown)
108 docs/auth_chain_difference_algorithm.md (new file)

@@ -0,0 +1,108 @@
+# Auth Chain Difference Algorithm
+
+The auth chain difference algorithm is used by V2 state resolution, where a
+naive implementation can be a significant source of CPU and DB usage.
+
+### Definitions
+
+A *state set* is a set of state events; e.g. the input of a state resolution
+algorithm is a collection of state sets.
+
+The *auth chain* of a set of events is all the events' auth events and *their*
+auth events, recursively (i.e. the events reachable by walking the graph induced
+by an event's auth events links).
+
+The *auth chain difference* of a collection of state sets is the union minus the
+intersection of the sets of auth chains corresponding to the state sets, i.e. an
+event is in the auth chain difference if it is reachable by walking the auth
+event graph from at least one of the state sets but not from *all* of the state
+sets.
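To make these definitions concrete, the following sketch computes the auth chain difference directly from the definitions; `auth_events`, a map from an event ID to the IDs of its auth events, is an assumed in-memory representation, not Synapse's actual storage:

```python
def auth_chain(events, auth_events):
    """Return the full auth chain of `events`: their auth events and the
    auth events of those, recursively (excluding `events` themselves)."""
    chain = set()
    todo = list(events)
    while todo:
        for parent in auth_events.get(todo.pop(), ()):
            if parent not in chain:
                chain.add(parent)
                todo.append(parent)
    return chain


def auth_chain_difference(state_sets, auth_events):
    # Union minus intersection of the per-state-set auth chains.
    chains = [auth_chain(state_set, auth_events) for state_set in state_sets]
    return set.union(*chains) - set.intersection(*chains)
```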
+## Breadth First Walk Algorithm
+
+A way of calculating the auth chain difference without calculating the full auth
+chains for each state set is to do a parallel breadth first walk (ordered by
+depth) of each state set's auth chain. By tracking which events are reachable
+from each state set we can finish early if every pending event is reachable from
+every state set.
+
+This can work well for state sets that have a small auth chain difference, but
+can be very inefficient for larger differences. However, this algorithm is still
+used if we don't have a chain cover index for the room (e.g. because we're in
+the process of indexing it).
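A rough rendering of the walk, reusing the assumed `auth_events` map from the previous sketch and omitting the depth-ordering and early-exit refinements described above:

```python
from collections import deque


def auth_chain_difference_bfs(state_sets, auth_events):
    """Walk the auth graph from every state set at once, recording which
    state sets reach each event.  An event is in the difference if it is
    reached from at least one, but not all, of the state sets."""
    num_sets = len(state_sets)
    reached_by = {}  # event_id -> indices of the state sets that reach it
    queue = deque(
        (parent, i)
        for i, state_set in enumerate(state_sets)
        for event_id in state_set
        for parent in auth_events.get(event_id, ())
    )
    while queue:
        event_id, i = queue.popleft()
        seen = reached_by.setdefault(event_id, set())
        if i in seen:
            continue  # already walked this event for this state set
        seen.add(i)
        queue.extend((parent, i) for parent in auth_events.get(event_id, ()))
    return {ev for ev, seen in reached_by.items() if len(seen) < num_sets}
```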
+## Chain Cover Index
+
+Synapse computes auth chain differences by pre-computing a "chain cover" index
+for the auth chain in a room, allowing efficient reachability queries like "is
+event A in the auth chain of event B". This is done by assigning every event a
+*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links*
+between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A`
+is in the auth chain of `B`) if and only if either:
+
+1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s
+   sequence number; or
+2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that
+   `L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`.
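These two rules can be sketched as follows, assuming the links are held as a map from a `(B chain, A chain)` pair to the `(start_seq_no, end_seq_no)` links between those chains — an assumed representation for illustration, not Synapse's schema:

```python
def is_reachable(a, b, links):
    """Is event `a` in the auth chain of event `b`?  Events are given as
    (chain_id, seq_no) pairs; `links` maps (b_chain, a_chain) to a list of
    (start_seq_no, end_seq_no) links between the two chains."""
    a_chain, a_seq = a
    b_chain, b_seq = b
    if a_chain == b_chain:
        # Rule 1: same chain, strictly smaller sequence number.
        return a_seq < b_seq
    # Rule 2: some link L with L.start_seq_no <= B.seq_no
    # and A.seq_no <= L.end_seq_no.
    return any(
        start <= b_seq and a_seq <= end
        for start, end in links.get((b_chain, a_chain), ())
    )
```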
+There are actually two potential implementations, one where we store links from
+each chain to every other reachable chain (the transitive closure of the links
+graph), and one where we remove redundant links (the transitive reduction of the
+links graph), e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1`
+would not be stored. Synapse uses the former implementation so that it doesn't
+need to recurse to test reachability between chains.
+
+### Example
+
+An example auth graph would look like the following, where chains have been
+formed based on type/state_key, are denoted by colour, and are labelled with
+`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey
+are those that would be removed in the second implementation described above).
+
+![Example](auth_chain_diff.dot.png)
+
+Note that we don't include all links between events and their auth events, as
+most of those links would be redundant. For example, all events point to the
+create event, but each chain only needs the one link from its base to the
+create event.
+
+## Using the Index
+
+This index can be used to calculate the auth chain difference of the state sets
+by looking at the chain ID and sequence numbers reachable from each state set:
+
+1. For every state set look up the chain ID/sequence numbers of each state event
+2. Use the index to find all chains and the maximum sequence number reachable
+   from each state set.
+3. The auth chain difference is then all events in each chain that have sequence
+   numbers between the maximum sequence number reachable from *any* state set and
+   the minimum reachable by *all* state sets (if any).
+
+Note that step 2 is effectively calculating the auth chain for each state set
+(in terms of chain IDs and sequence numbers), and step 3 is calculating the
+difference between the union and intersection of the auth chains.
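A sketch of steps 2-3, assuming the output of steps 1-2 is already available as one map per state set from chain ID to the maximum sequence number reachable from that set (with 0 when the chain is unreachable) — an assumed input shape for illustration:

```python
def auth_difference_from_index(state_set_chains):
    """Given, for each state set, a map of chain_id -> maximum sequence
    number reachable from that set, return the auth chain difference as a
    set of (chain_id, seq_no) pairs."""
    all_chains = set().union(*state_set_chains)
    difference = set()
    for chain_id in all_chains:
        maxes = [chains.get(chain_id, 0) for chains in state_set_chains]
        low = min(maxes)   # reachable from *all* state sets
        high = max(maxes)  # reachable from *any* state set
        # Everything in the half-open range (low, high] is in the difference.
        difference.update((chain_id, seq) for seq in range(low + 1, high + 1))
    return difference
```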
+### Worked Example
+
+For example, given the above graph, we can calculate the difference between
+state sets consisting of:
+
+1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and
+2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`.
+
+Using the index we see that the following auth chains are reachable from each
+state set:
+
+1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)`
+2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)`
+
+And so, for each chain, the ranges that are in the auth chain difference:
+
+1. Chain 1: None (since everything can reach the create event).
+2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state
+   sets and the maximum reachable is `2` (corresponding to Bob's second join).
+3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power
+   level).
+4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins).
+
+So the final result is: Bob's second join `(2,2)`, the second power level
+`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`.
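This worked example can be checked mechanically against the `auth_difference_from_index` sketch above:

```python
# Maximum sequence number reachable per chain, from steps 1-2 above.
S1 = {1: 1, 2: 2, 3: 1, 4: 1}
S2 = {1: 1, 2: 1, 3: 2, 4: 3}

assert auth_difference_from_index([S1, S2]) == {
    (2, 2),          # Bob's second join
    (3, 2),          # the second power level
    (4, 2), (4, 3),  # both of Alice's joins
}
```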
docs/openid.md

@@ -42,11 +42,10 @@ as follows:

 * For other installation mechanisms, see the documentation provided by the
   maintainer.

-To enable the OpenID integration, you should then add an `oidc_config` section
-to your configuration file (or uncomment the `enabled: true` line in the
-existing section). See [sample_config.yaml](./sample_config.yaml) for some
-sample settings, as well as the text below for example configurations for
-specific providers.
+To enable the OpenID integration, you should then add a section to the `oidc_providers`
+setting in your configuration file (or uncomment one of the existing examples).
+See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
+the text below for example configurations for specific providers.

 ## Sample configs

@@ -62,8 +61,9 @@ Directory (tenant) ID as it will be used in the Azure links.
 Edit your Synapse config file and change the `oidc_config` section:

 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: microsoft
+    idp_name: Microsoft
     issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
     client_id: "<client id>"
     client_secret: "<client secret>"

@@ -103,8 +103,9 @@ Run with `dex serve examples/config-dev.yaml`.
 Synapse config:

 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: dex
+    idp_name: "My Dex server"
     skip_verification: true # This is needed as Dex is served on an insecure endpoint
     issuer: "http://127.0.0.1:5556/dex"
     client_id: "synapse"

@@ -152,12 +153,17 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
 8. Copy Secret

 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: keycloak
+    idp_name: "My KeyCloak server"
     issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
     client_id: "synapse"
     client_secret: "copy secret generated from above"
     scopes: ["openid", "profile"]
+    user_mapping_provider:
+      config:
+        localpart_template: "{{ user.preferred_username }}"
+        display_name_template: "{{ user.name }}"
 ```
 ### [Auth0][auth0]

@@ -187,8 +193,9 @@ oidc_config:
 Synapse config:

 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: auth0
+    idp_name: Auth0
     issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
     client_id: "your-client-id" # TO BE FILLED
     client_secret: "your-client-secret" # TO BE FILLED

@@ -215,8 +222,9 @@ does not return a `sub` property, an alternative `subject_claim` has to be set.
 Synapse config:

 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: github
+    idp_name: Github
     discover: false
     issuer: "https://github.com/"
     client_id: "your-client-id" # TO BE FILLED

@@ -239,8 +247,9 @@ oidc_config:
 2. add an "OAuth Client ID" for a Web Application under "Credentials".
 3. Copy the Client ID and Client Secret, and add the following to your synapse config:
 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: google
+    idp_name: Google
     issuer: "https://accounts.google.com/"
     client_id: "your-client-id" # TO BE FILLED
     client_secret: "your-client-secret" # TO BE FILLED

@@ -262,8 +271,9 @@ oidc_config:
 Synapse config:

 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: twitch
+    idp_name: Twitch
     issuer: "https://id.twitch.tv/oauth2/"
     client_id: "your-client-id" # TO BE FILLED
     client_secret: "your-client-secret" # TO BE FILLED

@@ -283,8 +293,9 @@ oidc_config:
 Synapse config:

 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: gitlab
+    idp_name: Gitlab
     issuer: "https://gitlab.com/"
     client_id: "your-client-id" # TO BE FILLED
     client_secret: "your-client-secret" # TO BE FILLED
docs/postgres.md

@@ -18,7 +18,7 @@ connect to a postgres database.
 virtualenv](../INSTALL.md#installing-from-source), you can install
 the library with:

-    ~/synapse/env/bin/pip install matrix-synapse[postgres]
+    ~/synapse/env/bin/pip install "matrix-synapse[postgres]"

 (substituting the path to your virtualenv for `~/synapse/env`, if
 you used a different path). You will require the postgres
docs/sample_config.yaml

@@ -67,11 +67,16 @@ pid_file: DATADIR/homeserver.pid
 #
 #web_client_location: https://riot.example.com/

-# The public-facing base URL that clients use to access this HS
-# (not including _matrix/...). This is the same URL a user would
-# enter into the 'custom HS URL' field on their client. If you
-# use synapse with a reverse proxy, this should be the URL to reach
-# synapse via the proxy.
+# The public-facing base URL that clients use to access this Homeserver (not
+# including _matrix/...). This is the same URL a user might enter into the
+# 'Custom Homeserver URL' field on their client. If you use Synapse with a
+# reverse proxy, this should be the URL to reach Synapse via the proxy.
+# Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
+# 'listeners' below).
+#
+# If this is left unset, it defaults to 'https://<server_name>/'. (Note that
+# that will not work unless you configure Synapse or a reverse-proxy to listen
+# on port 443.)
 #
 #public_baseurl: https://example.com/

@@ -1150,8 +1155,9 @@ account_validity:
 # send an email to the account's email address with a renewal link. By
 # default, no such emails are sent.
 #
-# If you enable this setting, you will also need to fill out the 'email' and
-# 'public_baseurl' configuration sections.
+# If you enable this setting, you will also need to fill out the 'email'
+# configuration section. You should also check that 'public_baseurl' is set
+# correctly.
 #
 #renew_at: 1w

@@ -1242,8 +1248,7 @@ account_validity:
 # The identity server which we suggest that clients should use when users log
 # in on this server.
 #
-# (By default, no suggestion is made, so it is left up to the client.
-# This setting is ignored unless public_baseurl is also set.)
+# (By default, no suggestion is made, so it is left up to the client.)
 #
 #default_identity_server: https://matrix.org

@@ -1268,8 +1273,6 @@ account_validity:
 # by the Matrix Identity Service API specification:
 # https://matrix.org/docs/spec/identity_service/latest
 #
-# If a delegate is specified, the config option public_baseurl must also be filled out.
-#
 account_threepid_delegates:
   #email: https://example.com # Delegate email sending to example.com
   #msisdn: http://localhost:8090 # Delegate SMS sending to this local process
@@ -1709,141 +1712,153 @@ saml2_config:
   #idp_entityid: 'https://our_idp/entityid'


-# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
+# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
+# and login.
+#
+# Options for each entry include:
+#
+#   idp_id: a unique identifier for this identity provider. Used internally
+#       by Synapse; should be a single word such as 'github'.
+#
+#       Note that, if this is changed, users authenticating via that provider
+#       will no longer be recognised as the same user!
+#
+#   idp_name: A user-facing name for this identity provider, which is used to
+#       offer the user a choice of login mechanisms.
+#
+#   idp_icon: An optional icon for this identity provider, which is presented
+#       by identity picker pages. If given, must be an MXC URI of the format
+#       mxc://<server-name>/<media-id>
+#
+#   discover: set to 'false' to disable the use of the OIDC discovery mechanism
+#       to discover endpoints. Defaults to true.
+#
+#   issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+#       is enabled) to discover the provider's endpoints.
+#
+#   client_id: Required. oauth2 client id to use.
+#
+#   client_secret: Required. oauth2 client secret to use.
+#
+#   client_auth_method: auth method to use when exchanging the token. Valid
+#       values are 'client_secret_basic' (default), 'client_secret_post' and
+#       'none'.
+#
+#   scopes: list of scopes to request. This should normally include the "openid"
+#       scope. Defaults to ["openid"].
+#
+#   authorization_endpoint: the oauth2 authorization endpoint. Required if
+#       provider discovery is disabled.
+#
+#   token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+#       disabled.
+#
+#   userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+#       disabled and the 'openid' scope is not requested.
+#
+#   jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+#       the 'openid' scope is used.
+#
+#   skip_verification: set to 'true' to skip metadata verification. Use this if
+#       you are connecting to a provider that is not OpenID Connect compliant.
+#       Defaults to false. Avoid this in production.
+#
+#   user_profile_method: Whether to fetch the user profile from the userinfo
+#       endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+#
+#       Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+#       included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+#       userinfo endpoint.
+#
+#   allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+#       match a pre-existing account instead of failing. This could be used if
+#       switching from password logins to OIDC. Defaults to false.
+#
+#   user_mapping_provider: Configuration for how attributes returned from a OIDC
+#       provider are mapped onto a matrix user. This setting has the following
+#       sub-properties:
+#
+#       module: The class name of a custom mapping module. Default is
+#           'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
+#           See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+#           for information on implementing a custom mapping provider.
+#
+#       config: Configuration for the mapping provider module. This section will
+#           be passed as a Python dictionary to the user mapping provider
+#           module's `parse_config` method.
+#
+#           For the default provider, the following settings are available:
+#
+#             sub: name of the claim containing a unique identifier for the
+#                 user. Defaults to 'sub', which OpenID Connect compliant
+#                 providers should provide.
+#
+#             localpart_template: Jinja2 template for the localpart of the MXID.
+#                 If this is not set, the user will be prompted to choose their
+#                 own username.
+#
+#             display_name_template: Jinja2 template for the display name to set
+#                 on first login. If unset, no displayname will be set.
+#
+#             extra_attributes: a map of Jinja2 templates for extra attributes
+#                 to send back to the client during login.
+#                 Note that these are non-standard and clients will ignore them
+#                 without modifications.
+#
+#           When rendering, the Jinja2 templates are given a 'user' variable,
+#           which is set to the claims returned by the UserInfo Endpoint and/or
+#           in the ID Token.
 #
 # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
-# for some example configurations.
+# for information on how to configure these options.
 #
-oidc_config:
-  # Uncomment the following to enable authorization against an OpenID Connect
-  # server. Defaults to false.
-  #
-  #enabled: true
-
-  # Uncomment the following to disable use of the OIDC discovery mechanism to
-  # discover endpoints. Defaults to true.
-  #
-  #discover: false
-
-  # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
-  # discover the provider's endpoints.
-  #
-  # Required if 'enabled' is true.
-  #
-  #issuer: "https://accounts.example.com/"
-
-  # oauth2 client id to use.
-  #
-  # Required if 'enabled' is true.
-  #
-  #client_id: "provided-by-your-issuer"
-
-  # oauth2 client secret to use.
-  #
-  # Required if 'enabled' is true.
-  #
-  #client_secret: "provided-by-your-issuer"
-
-  # auth method to use when exchanging the token.
-  # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
-  # 'none'.
-  #
-  #client_auth_method: client_secret_post
-
-  # list of scopes to request. This should normally include the "openid" scope.
-  # Defaults to ["openid"].
-  #
-  #scopes: ["openid", "profile"]
-
-  # the oauth2 authorization endpoint. Required if provider discovery is disabled.
-  #
-  #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
-
-  # the oauth2 token endpoint. Required if provider discovery is disabled.
-  #
-  #token_endpoint: "https://accounts.example.com/oauth2/token"
-
-  # the OIDC userinfo endpoint. Required if discovery is disabled and the
-  # "openid" scope is not requested.
-  #
-  #userinfo_endpoint: "https://accounts.example.com/userinfo"
-
-  # URI where to fetch the JWKS. Required if discovery is disabled and the
-  # "openid" scope is used.
-  #
-  #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
-
-  # Uncomment to skip metadata verification. Defaults to false.
-  #
-  # Use this if you are connecting to a provider that is not OpenID Connect
-  # compliant.
-  # Avoid this in production.
-  #
-  #skip_verification: true
-
-  # Whether to fetch the user profile from the userinfo endpoint. Valid
-  # values are: "auto" or "userinfo_endpoint".
-  #
-  # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
-  # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
-  #
-  #user_profile_method: "userinfo_endpoint"
-
-  # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
-  # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
-  #
-  #allow_existing_users: true
-
-  # An external module can be provided here as a custom solution to mapping
-  # attributes returned from a OIDC provider onto a matrix user.
-  #
-  user_mapping_provider:
-    # The custom module's class. Uncomment to use a custom module.
-    # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
-    #
-    # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
-    # for information on implementing a custom mapping provider.
-    #
-    #module: mapping_provider.OidcMappingProvider
-
-    # Custom configuration values for the module. This section will be passed as
-    # a Python dictionary to the user mapping provider module's `parse_config`
-    # method.
-    #
-    # The examples below are intended for the default provider: they should be
-    # changed if using a custom provider.
-    #
-    config:
-      # name of the claim containing a unique identifier for the user.
-      # Defaults to `sub`, which OpenID Connect compliant providers should provide.
-      #
-      #subject_claim: "sub"
-
-      # Jinja2 template for the localpart of the MXID.
-      #
-      # When rendering, this template is given the following variables:
-      #  * user: The claims returned by the UserInfo Endpoint and/or in the ID
-      #    Token
-      #
-      # If this is not set, the user will be prompted to choose their
-      # own username.
-      #
-      #localpart_template: "{{ user.preferred_username }}"
-
-      # Jinja2 template for the display name to set on first login.
-      #
-      # If unset, no displayname will be set.
-      #
-      #display_name_template: "{{ user.given_name }} {{ user.last_name }}"
-
-      # Jinja2 templates for extra attributes to send back to the client during
-      # login.
-      #
-      # Note that these are non-standard and clients will ignore them without modifications.
-      #
-      #extra_attributes:
-      #  birthdate: "{{ user.birthdate }}"
+# For backwards compatibility, it is also possible to configure a single OIDC
+# provider via an 'oidc_config' setting. This is now deprecated and admins are
+# advised to migrate to the 'oidc_providers' format.
+#
+oidc_providers:
+  # Generic example
+  #
+  #- idp_id: my_idp
+  #  idp_name: "My OpenID provider"
+  #  discover: false
+  #  issuer: "https://accounts.example.com/"
+  #  client_id: "provided-by-your-issuer"
+  #  client_secret: "provided-by-your-issuer"
+  #  client_auth_method: client_secret_post
+  #  scopes: ["openid", "profile"]
+  #  authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+  #  token_endpoint: "https://accounts.example.com/oauth2/token"
+  #  userinfo_endpoint: "https://accounts.example.com/userinfo"
+  #  jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+  #  skip_verification: true
+
+  # For use with Keycloak
+  #
+  #- idp_id: keycloak
+  #  idp_name: Keycloak
+  #  issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
+  #  client_id: "synapse"
+  #  client_secret: "copy secret generated in Keycloak UI"
+  #  scopes: ["openid", "profile"]
+
+  # For use with Github
+  #
+  #- idp_id: google
+  #  idp_name: Google
+  #  discover: false
+  #  issuer: "https://github.com/"
+  #  client_id: "your-client-id" # TO BE FILLED
+  #  client_secret: "your-client-secret" # TO BE FILLED
+  #  authorization_endpoint: "https://github.com/login/oauth/authorize"
+  #  token_endpoint: "https://github.com/login/oauth/access_token"
+  #  userinfo_endpoint: "https://api.github.com/user"
+  #  scopes: ["read:user"]
+  #  user_mapping_provider:
+  #    config:
+  #      subject_claim: "id"
+  #      localpart_template: "{ user.login }"
+  #      display_name_template: "{ user.name }"


 # Enable Central Authentication Service (CAS) for registration and login.
@ -1893,9 +1908,9 @@ sso:
|
||||||
# phishing attacks from evil.site. To avoid this, include a slash after the
|
# phishing attacks from evil.site. To avoid this, include a slash after the
|
||||||
# hostname: "https://my.client/".
|
# hostname: "https://my.client/".
|
||||||
#
|
#
|
||||||
# If public_baseurl is set, then the login fallback page (used by clients
|
# The login fallback page (used by clients that don't natively support the
|
||||||
# that don't natively support the required login flows) is whitelisted in
|
# required login flows) is automatically whitelisted in addition to any URLs
|
||||||
# addition to any URLs in this list.
|
# in this list.
|
||||||
#
|
#
|
||||||
# By default, this list is empty.
|
# By default, this list is empty.
|
||||||
#
|
#
|
||||||
@@ -1909,6 +1924,31 @@ sso:
    #
    # Synapse will look for the following templates in this directory:
    #
+   # * HTML page to prompt the user to choose an Identity Provider during
+   #   login: 'sso_login_idp_picker.html'.
+   #
+   #   This is only used if multiple SSO Identity Providers are configured.
+   #
+   #   When rendering, this template is given the following variables:
+   #     * redirect_url: the URL that the user will be redirected to after
+   #       login. Needs manual escaping (see
+   #       https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+   #
+   #     * server_name: the homeserver's name.
+   #
+   #     * providers: a list of available Identity Providers. Each element is
+   #       an object with the following attributes:
+   #         * idp_id: unique identifier for the IdP
+   #         * idp_name: user-facing name for the IdP
+   #
+   #   The rendered HTML page should contain a form which submits its results
+   #   back as a GET request, with the following query parameters:
+   #
+   #     * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
+   #       to the template)
+   #
+   #     * idp: the 'idp_id' of the chosen IDP.
+   #
    # * HTML page for a confirmation step before redirecting back to the client
    #   with the login token: 'sso_redirect_confirm.html'.
    #
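To make the picker's contract concrete, here is a reduced, hypothetical template rendered with plain Jinja2 — not Synapse's own `sso_login_idp_picker.html` — exercising exactly the documented variables and submitting `redirectUrl` and `idp` back as GET parameters (the `pick_idp` endpoint is registered at `/_synapse/client/pick_idp` later in this diff):

```python
import jinja2

# Hypothetical, minimal picker template: it only uses the variables the
# comments above document (redirect_url, providers with idp_id/idp_name).
PICKER = """\
<form method="get" action="pick_idp">
  <input type="hidden" name="redirectUrl" value="{{ redirect_url }}">
  {% for p in providers %}
  <button type="submit" name="idp" value="{{ p.idp_id }}">{{ p.idp_name }}</button>
  {% endfor %}
</form>
"""

env = jinja2.Environment(autoescape=True)
html = env.from_string(PICKER).render(
    redirect_url="https://client.example/?after-login",
    server_name="example.com",
    providers=[{"idp_id": "github", "idp_name": "GitHub"}],
)
print(html)
```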
@@ -1944,6 +1984,14 @@ sso:
    #
    # This template has no additional variables.
    #
+   # * HTML page shown after a user-interactive authentication session which
+   #   does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
+   #
+   #   When rendering, this template is given the following variables:
+   #     * server_name: the homeserver's name.
+   #     * user_id_to_verify: the MXID of the user that we are trying to
+   #       validate.
+   #
    # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
    #   attempts to login: 'sso_account_deactivated.html'.
    #
@@ -31,7 +31,7 @@ There is no need for a separate configuration file for the master process.
 1. Adjust synapse configuration files as above.
 1. Copy the `*.service` and `*.target` files in [system](system) to
    `/etc/systemd/system`.
-1. Run `systemctl deamon-reload` to tell systemd to load the new unit files.
+1. Run `systemctl daemon-reload` to tell systemd to load the new unit files.
 1. Run `systemctl enable matrix-synapse.service`. This will configure the
    synapse master process to be started as part of the `matrix-synapse.target`
    target.
@@ -16,6 +16,9 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only
 be used for demo purposes and any admin considering workers should already be
 running PostgreSQL.

+See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability
+for a higher level overview.
+
 ## Main process/worker communication

 The processes communicate with each other via a Synapse-specific protocol called
@@ -56,7 +59,7 @@ The appropriate dependencies must also be installed for Synapse. If using a
 virtualenv, these can be installed with:

 ```sh
-pip install matrix-synapse[redis]
+pip install "matrix-synapse[redis]"
 ```

 Note that these dependencies are included when synapse is installed with `pip
@@ -214,6 +217,7 @@ expressions:
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
     ^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
+    ^/_matrix/client/(api/v1|r0|unstable)/devices$
     ^/_matrix/client/(api/v1|r0|unstable)/keys/query$
     ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
     ^/_matrix/client/versions$
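These patterns are anchored regular expressions, so a reverse-proxy rule (or a quick script like the illustrative one below) can use them directly to decide whether a request may be routed to the generic worker:

```python
import re

# A few of the patterns from the list above, including the newly added
# /devices endpoint; purely illustrative.
CLIENT_API_PATTERNS = [
    r"^/_matrix/client/(api/v1|r0|unstable)/devices$",
    r"^/_matrix/client/(api/v1|r0|unstable)/keys/query$",
    r"^/_matrix/client/versions$",
]

def worker_can_handle(path: str) -> bool:
    # Routable if any anchored pattern matches the request path.
    return any(re.match(p, path) for p in CLIENT_API_PATTERNS)

assert worker_can_handle("/_matrix/client/r0/devices")
assert not worker_can_handle("/_matrix/client/r0/login/sso")
```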
mypy.ini
@@ -100,9 +100,11 @@ files =
   synapse/util/async_helpers.py,
   synapse/util/caches,
   synapse/util/metrics.py,
+  synapse/util/stringutils.py,
   tests/replication,
   tests/test_utils,
   tests/handlers/test_password_providers.py,
+  tests/rest/client/v1/test_login.py,
   tests/rest/client/v2_alpha/test_auth.py,
   tests/util/test_stream_change_cache.py
@@ -70,7 +70,7 @@ logger = logging.getLogger("synapse_port_db")

 BOOLEAN_COLUMNS = {
     "events": ["processed", "outlier", "contains_url"],
-    "rooms": ["is_public"],
+    "rooms": ["is_public", "has_auth_chain_index"],
     "event_edges": ["is_state"],
     "presence_list": ["accepted"],
     "presence_stream": ["currently_active"],
@@ -629,6 +629,7 @@ class Porter(object):
         await self._setup_state_group_id_seq()
         await self._setup_user_id_seq()
         await self._setup_events_stream_seqs()
+        await self._setup_device_inbox_seq()

         # Step 3. Get tables.
         self.progress.set_state("Fetching tables")
@@ -911,6 +912,32 @@ class Porter(object):
             "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
         )

+    async def _setup_device_inbox_seq(self):
+        """Set the device inbox sequence to the correct value.
+        """
+        curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="device_inbox",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), 1)",
+            allow_none=True,
+        )
+
+        curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="device_federation_outbox",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), 1)",
+            allow_none=True,
+        )
+
+        next_id = max(curr_local_id, curr_federation_id) + 1
+
+        def r(txn):
+            txn.execute(
+                "ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,)
+            )
+
+        return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r)
+
+
 ##############################################
 # The following is simply UI stuff
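The new `_setup_device_inbox_seq` step exists because the port copies rows verbatim: unless the Postgres `device_inbox_sequence` is restarted above the highest ported `stream_id`, the first new device message would reuse an existing id. A standalone sketch of the same computation, using the table and sequence names from the code above:

```python
import sqlite3

def compute_next_device_inbox_id(conn: sqlite3.Connection) -> int:
    # Mirror of the porter logic: take the highest stream_id seen in either
    # outbox table, and start the sequence one past it.
    (local_max,) = conn.execute(
        "SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox"
    ).fetchone()
    (federation_max,) = conn.execute(
        "SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox"
    ).fetchone()
    return max(local_max, federation_max) + 1

# On the Postgres side the result is then applied with:
#   ALTER SEQUENCE device_inbox_sequence RESTART WITH <next_id>
```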
@@ -15,16 +15,7 @@

 # Stub for frozendict.

-from typing import (
-    Any,
-    Hashable,
-    Iterable,
-    Iterator,
-    Mapping,
-    overload,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload

 _KT = TypeVar("_KT", bound=Hashable)  # Key type.
 _VT = TypeVar("_VT")  # Value type.
@@ -7,17 +7,17 @@ from typing import (
     Callable,
     Dict,
     Hashable,
-    Iterator,
-    Iterable,
     ItemsView,
+    Iterable,
+    Iterator,
     KeysView,
     List,
     Mapping,
     Optional,
     Sequence,
+    Tuple,
     Type,
     TypeVar,
-    Tuple,
     Union,
     ValuesView,
     overload,
@@ -16,7 +16,7 @@
 """Contains *incomplete* type hints for txredisapi.
 """

-from typing import List, Optional, Union, Type
+from typing import List, Optional, Type, Union

 class RedisProtocol:
     def publish(self, channel: str, message: bytes): ...
@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass

-__version__ = "1.25.0"
+__version__ = "1.26.0rc1"

 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
@@ -33,6 +33,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
+from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
@@ -186,8 +187,8 @@ class Auth:
             AuthError if access is denied for the user in the access token
         """
         try:
-            ip_addr = self.hs.get_ip_from_request(request)
-            user_agent = request.get_user_agent("")
+            ip_addr = request.getClientIP()
+            user_agent = get_request_user_agent(request)

             access_token = self.get_access_token_from_request(request)

@@ -275,7 +276,7 @@ class Auth:
             return None, None

         if app_service.ip_range_whitelist:
-            ip_address = IPAddress(self.hs.get_ip_from_request(request))
+            ip_address = IPAddress(request.getClientIP())
             if ip_address not in app_service.ip_range_whitelist:
                 return None, None
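The appservice branch above relies on netaddr for the membership test. A reduced illustration of what it checks (the whitelist contents here are hypothetical; in Synapse, `ip_range_whitelist` is an `IPSet` built from the appservice registration):

```python
from netaddr import IPAddress, IPSet

# Requests from addresses outside the set are rejected, exactly as in the
# `if ip_address not in app_service.ip_range_whitelist` branch above.
whitelist = IPSet(["192.168.0.0/16", "127.0.0.1"])

assert IPAddress("192.168.1.5") in whitelist
assert IPAddress("10.0.0.1") not in whitelist
```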
@@ -51,11 +51,11 @@ class RoomDisposition:
 class RoomVersion:
     """An object which describes the unique attributes of a room version."""

-    identifier = attr.ib()  # str; the identifier for this version
-    disposition = attr.ib()  # str; one of the RoomDispositions
-    event_format = attr.ib()  # int; one of the EventFormatVersions
-    state_res = attr.ib()  # int; one of the StateResolutionVersions
-    enforce_key_validity = attr.ib()  # bool
+    identifier = attr.ib(type=str)  # the identifier for this version
+    disposition = attr.ib(type=str)  # one of the RoomDispositions
+    event_format = attr.ib(type=int)  # one of the EventFormatVersions
+    state_res = attr.ib(type=int)  # one of the StateResolutionVersions
+    enforce_key_validity = attr.ib(type=bool)

     # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
     special_case_aliases_auth = attr.ib(type=bool)
@@ -64,9 +64,11 @@ class RoomVersion:
     #  * Floats
     #  * NaN, Infinity, -Infinity
     strict_canonicaljson = attr.ib(type=bool)
-    # bool: MSC2209: Check 'notifications' key while verifying
+    # MSC2209: Check 'notifications' key while verifying
     # m.room.power_levels auth rules.
     limit_notifications_power_levels = attr.ib(type=bool)
+    # MSC2174/MSC2176: Apply updated redaction rules algorithm.
+    msc2176_redaction_rules = attr.ib(type=bool)


 class RoomVersions:
@@ -79,6 +81,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V2 = RoomVersion(
         "2",
@@ -89,6 +92,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V3 = RoomVersion(
         "3",
@@ -99,6 +103,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V4 = RoomVersion(
         "4",
@@ -109,6 +114,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V5 = RoomVersion(
         "5",
@@ -119,6 +125,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V6 = RoomVersion(
         "6",
@@ -129,6 +136,18 @@ class RoomVersions:
         special_case_aliases_auth=False,
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+    )
+    MSC2176 = RoomVersion(
+        "org.matrix.msc2176",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=True,
     )


@@ -141,5 +160,6 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V4,
         RoomVersions.V5,
         RoomVersions.V6,
+        RoomVersions.MSC2176,
     )
 }  # type: Dict[str, RoomVersion]
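With the capability carried as a flag on `RoomVersion`, callers can feature-test a room version instead of string-matching identifiers. A small illustration (the function name is ours, not Synapse's):

```python
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS

def uses_msc2176_redaction(room_version_id: str) -> bool:
    # Look the version up and read the capability flag added above.
    room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
    if room_version is None:
        raise ValueError("unknown room version %r" % (room_version_id,))
    return room_version.msc2176_redaction_rules

assert not uses_msc2176_redaction("6")
assert uses_msc2176_redaction("org.matrix.msc2176")
```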
@@ -42,8 +42,6 @@ class ConsentURIBuilder:
         """
         if hs_config.form_secret is None:
             raise ConfigError("form_secret not set in config")
-        if hs_config.public_baseurl is None:
-            raise ConfigError("public_baseurl not set in config")

         self._hmac_secret = hs_config.form_secret.encode("utf-8")
         self._public_baseurl = hs_config.public_baseurl
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2017 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +20,7 @@ import signal
 import socket
 import sys
 import traceback
-from typing import Iterable
+from typing import Awaitable, Callable, Iterable

 from typing_extensions import NoReturn

@@ -143,6 +144,45 @@ def quit_with_error(error_string: str) -> NoReturn:
     sys.exit(1)


+def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
+    """Register a callback with the reactor, to be called once it is running
+
+    This can be used to initialise parts of the system which require an asynchronous
+    setup.
+
+    Any exception raised by the callback will be printed and logged, and the process
+    will exit.
+    """
+
+    async def wrapper():
+        try:
+            await cb(*args, **kwargs)
+        except Exception:
+            # previously, we used Failure().printTraceback() here, in the hope that
+            # would give better tracebacks than traceback.print_exc(). However, that
+            # doesn't handle chained exceptions (with a __cause__ or __context__) well,
+            # and I *think* the need for Failure() is reduced now that we mostly use
+            # async/await.
+
+            # Write the exception to both the logs *and* the unredirected stderr,
+            # because people tend to get confused if it only goes to one or the other.
+            #
+            # One problem with this is that if people are using a logging config that
+            # logs to the console (as is common eg under docker), they will get two
+            # copies of the exception. We could maybe try to detect that, but it's
+            # probably a cost we can bear.
+            logger.fatal("Error during startup", exc_info=True)
+            print("Error during startup:", file=sys.__stderr__)
+            traceback.print_exc(file=sys.__stderr__)
+
+            # it's no use calling sys.exit here, since that just raises a SystemExit
+            # exception which is then caught by the reactor, and everything carries
+            # on as normal.
+            os._exit(1)
+
+    reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
+
+
 def listen_metrics(bind_addresses, port):
     """
     Start Prometheus metrics server.
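`register_start` is used further down in this diff as `register_start(_base.start, hs, config.worker_listeners)`. Any coroutine works, and extra positional and keyword arguments are forwarded to it once the reactor is running; a hypothetical startup task for illustration:

```python
async def warm_caches(hs, *, limit: int = 100) -> None:
    # Hypothetical asynchronous initialisation work.
    ...

# If warm_caches raises, register_start's wrapper logs the error to both the
# log and stderr and exits the process, instead of letting it limp on.
register_start(warm_caches, hs, limit=500)
```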
@@ -227,7 +267,7 @@ def refresh_certificate(hs):
     logger.info("Context factories updated.")


-def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
+async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
     """
     Start a Synapse server or worker.

@@ -241,10 +281,8 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
         hs: homeserver instance
         listeners: Listener configuration ('listeners' in homeserver.yaml)
     """
-    try:
     # Set up the SIGHUP machinery.
     if hasattr(signal, "SIGHUP"):

         reactor = hs.get_reactor()

         @wrap_as_background_process("sighup")
@@ -304,12 +342,6 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
     if sys.version_info >= (3, 7):
         gc.collect()
         gc.freeze()
-    except Exception:
-        traceback.print_exc(file=sys.stderr)
-        reactor = hs.get_reactor()
-        if reactor.running:
-            reactor.stop()
-        sys.exit(1)


 def setup_sentry(hs):
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set

 from typing_extensions import ContextManager

-from twisted.internet import address, reactor
+from twisted.internet import address

 import synapse
 import synapse.events
@@ -34,6 +34,7 @@ from synapse.api.urls import (
     SERVER_KEY_V2_PREFIX,
 )
 from synapse.app import _base
+from synapse.app._base import register_start
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
@@ -99,21 +100,37 @@ from synapse.rest.client.v1.profile import (
 )
 from synapse.rest.client.v1.push_rule import PushRuleRestServlet
 from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha import groups, sync, user_directory
+from synapse.rest.client.v2_alpha import (
+    account_data,
+    groups,
+    read_marker,
+    receipts,
+    room_keys,
+    sync,
+    tags,
+    user_directory,
+)
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
 from synapse.rest.client.v2_alpha.account_data import (
     AccountDataServlet,
     RoomAccountDataServlet,
 )
-from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
+from synapse.rest.client.v2_alpha.devices import DevicesRestServlet
+from synapse.rest.client.v2_alpha.keys import (
+    KeyChangesServlet,
+    KeyQueryServlet,
+    OneTimeKeyServlet,
+)
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.server import HomeServer, cache_in_self
 from synapse.storage.databases.main.censor_events import CensorEventsStore
 from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
+from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
 from synapse.storage.databases.main.media_repository import MediaRepositoryStore
 from synapse.storage.databases.main.metrics import ServerMetricsStore
 from synapse.storage.databases.main.monthly_active_users import (
@@ -445,6 +462,7 @@ class GenericWorkerSlavedStore(
     UserDirectoryStore,
     StatsStore,
     UIAuthWorkerStore,
+    EndToEndRoomKeyStore,
     SlavedDeviceInboxStore,
     SlavedDeviceStore,
     SlavedReceiptsStore,
@@ -501,7 +519,9 @@ class GenericWorkerServer(HomeServer):
                     RegisterRestServlet(self).register(resource)
                     LoginRestServlet(self).register(resource)
                     ThreepidRestServlet(self).register(resource)
+                    DevicesRestServlet(self).register(resource)
                     KeyQueryServlet(self).register(resource)
+                    OneTimeKeyServlet(self).register(resource)
                     KeyChangesServlet(self).register(resource)
                     VoipRestServlet(self).register(resource)
                     PushRuleRestServlet(self).register(resource)
@@ -519,6 +539,13 @@ class GenericWorkerServer(HomeServer):
                     room.register_servlets(self, resource, True)
                     room.register_deprecated_servlets(self, resource)
                     InitialSyncRestServlet(self).register(resource)
+                    room_keys.register_servlets(self, resource)
+                    tags.register_servlets(self, resource)
+                    account_data.register_servlets(self, resource)
+                    receipts.register_servlets(self, resource)
+                    read_marker.register_servlets(self, resource)
+
+                    SendToDeviceRestServlet(self).register(resource)

                     user_directory.register_servlets(self, resource)

@@ -957,9 +984,7 @@ def start(config_options):
     # streams. Will no-op if no streams can be written to by this worker.
     hs.get_replication_streamer()

-    reactor.addSystemEventTrigger(
-        "before", "startup", _base.start, hs, config.worker_listeners
-    )
+    register_start(_base.start, hs, config.worker_listeners)

    _base.start_worker_reactor("synapse-generic-worker", config)
@@ -15,15 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import gc
 import logging
 import os
 import sys
 from typing import Iterable, Iterator

-from twisted.application import service
-from twisted.internet import defer, reactor
-from twisted.python.failure import Failure
+from twisted.internet import reactor
 from twisted.web.resource import EncodingResourceWrapper, IResource
 from twisted.web.server import GzipEncoderFactory
 from twisted.web.static import File
@@ -40,7 +37,7 @@ from synapse.api.urls import (
     WEB_CLIENT_PREFIX,
 )
 from synapse.app import _base
-from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
+from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
 from synapse.config._base import ConfigError
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
@@ -63,6 +60,7 @@ from synapse.rest import ClientRestResource
 from synapse.rest.admin import AdminRestResource
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
 from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
@@ -72,7 +70,6 @@ from synapse.storage.prepare_database import UpgradeDatabaseException
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.module_loader import load_module
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string

 logger = logging.getLogger("synapse.app.homeserver")
@@ -194,6 +191,7 @@ class SynapseHomeServer(HomeServer):
                 "/.well-known/matrix/client": WellKnownResource(self),
                 "/_synapse/admin": AdminRestResource(self),
                 "/_synapse/client/pick_username": pick_username_resource(self),
+                "/_synapse/client/pick_idp": PickIdpResource(self),
             }
         )

@@ -415,7 +413,6 @@ def setup(config_options):
     _base.refresh_certificate(hs)

     async def start():
-        try:
         # Run the ACME provisioning code, if it's enabled.
         if hs.config.acme_enabled:
             acme = hs.get_acme_handler()
@@ -432,23 +429,12 @@ def setup(config_options):
             oidc = hs.get_oidc_handler()
             # Loading the provider metadata also ensures the provider config is valid.
             await oidc.load_metadata()
-            await oidc.load_jwks()

-            _base.start(hs, config.listeners)
+        await _base.start(hs, config.listeners)

         hs.get_datastore().db_pool.updates.start_doing_background_updates()
-        except Exception:
-            # Print the exception and bail out.
-            print("Error during startup:", file=sys.stderr)
-
-            # this gives better tracebacks than traceback.print_exc()
-            Failure().printTraceback(file=sys.stderr)
-
-            if reactor.running:
-                reactor.stop()
-            sys.exit(1)
-
-    reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
+    register_start(start)

     return hs

@@ -485,25 +471,6 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
         e = e.__cause__


-class SynapseService(service.Service):
-    """
-    A twisted Service class that will start synapse. Used to run synapse
-    via twistd and a .tac.
-    """
-
-    def __init__(self, config):
-        self.config = config
-
-    def startService(self):
-        hs = setup(self.config)
-        change_resource_limit(hs.config.soft_file_limit)
-        if hs.config.gc_thresholds:
-            gc.set_threshold(*hs.config.gc_thresholds)
-
-    def stopService(self):
-        return self._port.stopListening()
-
-
 def run(hs):
     PROFILE_SYNAPSE = False
     if PROFILE_SYNAPSE:
@@ -252,10 +252,11 @@ class Config:
         env = jinja2.Environment(loader=loader, autoescape=autoescape)

         # Update the environment with our custom filters
-        env.filters.update({"format_ts": _format_ts_filter})
-        if self.public_baseurl:
-            env.filters.update(
-                {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)}
-            )
+        env.filters.update(
+            {
+                "format_ts": _format_ts_filter,
+                "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
+            }
+        )

         for filename in filenames:
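Since `public_baseurl` now always has a value, the `mxc_to_http` filter no longer needs to be registered conditionally. A reduced, self-contained sketch of the resulting environment; the two factory functions below are stand-ins for Synapse's private helpers of the same names, and the thumbnail URL layout is assumed for illustration:

```python
import time
import jinja2

def _format_ts_filter(value: int, format: str) -> str:
    # Render a millisecond timestamp with a strftime format string.
    return time.strftime(format, time.localtime(value / 1000))

def _create_mxc_to_http_filter(public_baseurl: str):
    def mxc_to_http(value: str, width: int, height: int, resize_method: str = "crop") -> str:
        # Turn mxc://<server>/<media-id> into a thumbnail URL under the
        # homeserver's public base URL (URL layout assumed, not authoritative).
        server_name, _, media_id = value[len("mxc://"):].partition("/")
        return "%s_matrix/media/v1/thumbnail/%s/%s?width=%d&height=%d&method=%s" % (
            public_baseurl, server_name, media_id, width, height, resize_method,
        )
    return mxc_to_http

env = jinja2.Environment()
env.filters.update(
    {
        "format_ts": _format_ts_filter,
        "mxc_to_http": _create_mxc_to_http_filter("https://matrix.example.com/"),
    }
)
print(env.from_string("{{ 'mxc://hs.example/abc123' | mxc_to_http(32, 32) }}").render())
```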
@@ -56,7 +56,7 @@ def json_error_to_config_error(
     """
     # copy `config_path` before modifying it.
     path = list(config_path)
-    for p in list(e.path):
+    for p in list(e.absolute_path):
         if isinstance(p, int):
             path.append("<item %i>" % p)
         else:
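`absolute_path` matters when the failing value sits inside a sub-schema, such as the `oneOf` used for OIDC providers later in this diff: `e.path` is relative to the sub-schema that raised, while `e.absolute_path` is rooted at the whole instance. A quick demonstration with a toy schema:

```python
import jsonschema

schema = {
    "type": "object",
    "properties": {"providers": {"type": "array", "items": {"type": "string"}}},
}

try:
    jsonschema.validate({"providers": ["ok", 5]}, schema)
except jsonschema.ValidationError as e:
    # Prints ['providers', 1] -- the full path from the top of the instance.
    print(list(e.absolute_path))
```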
@@ -40,7 +40,7 @@ class CasConfig(Config):
         self.cas_required_attributes = {}

     def generate_config_section(self, config_dir_path, server_name, **kwargs):
-        return """
+        return """\
        # Enable Central Authentication Service (CAS) for registration and login.
        #
        cas_config:
@@ -166,11 +166,6 @@ class EmailConfig(Config):
             if not self.email_notif_from:
                 missing.append("email.notif_from")

-            # public_baseurl is required to build password reset and validation links that
-            # will be emailed to users
-            if config.get("public_baseurl") is None:
-                missing.append("public_baseurl")
-
             if missing:
                 raise ConfigError(
                     MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),)
@@ -269,9 +264,6 @@ class EmailConfig(Config):
             if not self.email_notif_from:
                 missing.append("email.notif_from")

-            if config.get("public_baseurl") is None:
-                missing.append("public_baseurl")
-
             if missing:
                 raise ConfigError(
                     "email.enable_notifs is True but required keys are missing: %s"
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2020 Quentin Gliech
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,8 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import string
+from typing import Iterable, Optional, Tuple, Type
+
+import attr
+
+from synapse.config._util import validate_config
 from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.types import Collection, JsonDict
 from synapse.util.module_loader import load_module
+from synapse.util.stringutils import parse_and_validate_mxc_uri

 from ._base import Config, ConfigError
@@ -25,48 +34,285 @@ class OIDCConfig(Config):
     section = "oidc"

     def read_config(self, config, **kwargs):
-        self.oidc_enabled = False
-
-        oidc_config = config.get("oidc_config")
-
-        if not oidc_config or not oidc_config.get("enabled", False):
+        self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
+        if not self.oidc_providers:
             return

         try:
             check_requirements("oidc")
         except DependencyException as e:
-            raise ConfigError(e.message)
+            raise ConfigError(e.message) from e

         public_baseurl = self.public_baseurl
-        if public_baseurl is None:
-            raise ConfigError("oidc_config requires a public_baseurl to be set")
         self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"

-        self.oidc_enabled = True
-        self.oidc_discover = oidc_config.get("discover", True)
-        self.oidc_issuer = oidc_config["issuer"]
-        self.oidc_client_id = oidc_config["client_id"]
-        self.oidc_client_secret = oidc_config["client_secret"]
-        self.oidc_client_auth_method = oidc_config.get(
-            "client_auth_method", "client_secret_basic"
-        )
-        self.oidc_scopes = oidc_config.get("scopes", ["openid"])
-        self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
-        self.oidc_token_endpoint = oidc_config.get("token_endpoint")
-        self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
-        self.oidc_jwks_uri = oidc_config.get("jwks_uri")
-        self.oidc_skip_verification = oidc_config.get("skip_verification", False)
-        self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
-        self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
+    @property
+    def oidc_enabled(self) -> bool:
+        # OIDC is enabled if we have a provider
+        return bool(self.oidc_providers)
+
+    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+        return """\
+        # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
+        # and login.
+        #
+        # Options for each entry include:
+        #
+        #   idp_id: a unique identifier for this identity provider. Used internally
+        #       by Synapse; should be a single word such as 'github'.
+        #
+        #       Note that, if this is changed, users authenticating via that provider
+        #       will no longer be recognised as the same user!
+        #
+        #   idp_name: A user-facing name for this identity provider, which is used to
+        #       offer the user a choice of login mechanisms.
+        #
+        #   idp_icon: An optional icon for this identity provider, which is presented
+        #       by identity picker pages. If given, must be an MXC URI of the format
+        #       mxc://<server-name>/<media-id>
+        #
+        #   discover: set to 'false' to disable the use of the OIDC discovery mechanism
+        #       to discover endpoints. Defaults to true.
+        #
+        #   issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+        #       is enabled) to discover the provider's endpoints.
+        #
+        #   client_id: Required. oauth2 client id to use.
+        #
+        #   client_secret: Required. oauth2 client secret to use.
+        #
+        #   client_auth_method: auth method to use when exchanging the token. Valid
+        #       values are 'client_secret_basic' (default), 'client_secret_post' and
+        #       'none'.
+        #
+        #   scopes: list of scopes to request. This should normally include the "openid"
+        #       scope. Defaults to ["openid"].
+        #
+        #   authorization_endpoint: the oauth2 authorization endpoint. Required if
+        #       provider discovery is disabled.
+        #
+        #   token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+        #       disabled.
+        #
+        #   userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+        #       disabled and the 'openid' scope is not requested.
+        #
+        #   jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+        #       the 'openid' scope is used.
+        #
+        #   skip_verification: set to 'true' to skip metadata verification. Use this if
+        #       you are connecting to a provider that is not OpenID Connect compliant.
+        #       Defaults to false. Avoid this in production.
+        #
+        #   user_profile_method: Whether to fetch the user profile from the userinfo
+        #       endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+        #
+        #       Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+        #       included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+        #       userinfo endpoint.
+        #
+        #   allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+        #       match a pre-existing account instead of failing. This could be used if
+        #       switching from password logins to OIDC. Defaults to false.
+        #
+        #   user_mapping_provider: Configuration for how attributes returned from a OIDC
+        #       provider are mapped onto a matrix user. This setting has the following
+        #       sub-properties:
+        #
+        #       module: The class name of a custom mapping module. Default is
+        #           {mapping_provider!r}.
+        #           See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+        #           for information on implementing a custom mapping provider.
+        #
+        #       config: Configuration for the mapping provider module. This section will
+        #           be passed as a Python dictionary to the user mapping provider
+        #           module's `parse_config` method.
+        #
+        #           For the default provider, the following settings are available:
+        #
+        #             sub: name of the claim containing a unique identifier for the
+        #                 user. Defaults to 'sub', which OpenID Connect compliant
+        #                 providers should provide.
+        #
+        #             localpart_template: Jinja2 template for the localpart of the MXID.
+        #                 If this is not set, the user will be prompted to choose their
+        #                 own username.
+        #
+        #             display_name_template: Jinja2 template for the display name to set
+        #                 on first login. If unset, no displayname will be set.
+        #
+        #             extra_attributes: a map of Jinja2 templates for extra attributes
+        #                 to send back to the client during login.
+        #                 Note that these are non-standard and clients will ignore them
+        #                 without modifications.
+        #
+        #       When rendering, the Jinja2 templates are given a 'user' variable,
+        #       which is set to the claims returned by the UserInfo Endpoint and/or
+        #       in the ID Token.
+        #
+        # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
+        # for information on how to configure these options.
+        #
+        # For backwards compatibility, it is also possible to configure a single OIDC
+        # provider via an 'oidc_config' setting. This is now deprecated and admins are
+        # advised to migrate to the 'oidc_providers' format.
+        #
+        oidc_providers:
+          # Generic example
+          #
+          #- idp_id: my_idp
+          #  idp_name: "My OpenID provider"
+          #  discover: false
+          #  issuer: "https://accounts.example.com/"
+          #  client_id: "provided-by-your-issuer"
+          #  client_secret: "provided-by-your-issuer"
+          #  client_auth_method: client_secret_post
+          #  scopes: ["openid", "profile"]
+          #  authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+          #  token_endpoint: "https://accounts.example.com/oauth2/token"
+          #  userinfo_endpoint: "https://accounts.example.com/userinfo"
+          #  jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+          #  skip_verification: true
+
+          # For use with Keycloak
+          #
+          #- idp_id: keycloak
+          #  idp_name: Keycloak
+          #  issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
+          #  client_id: "synapse"
+          #  client_secret: "copy secret generated in Keycloak UI"
+          #  scopes: ["openid", "profile"]
+
+          # For use with Github
+          #
+          #- idp_id: github
+          #  idp_name: Github
+          #  discover: false
+          #  issuer: "https://github.com/"
+          #  client_id: "your-client-id" # TO BE FILLED
+          #  client_secret: "your-client-secret" # TO BE FILLED
+          #  authorization_endpoint: "https://github.com/login/oauth/authorize"
+          #  token_endpoint: "https://github.com/login/oauth/access_token"
+          #  userinfo_endpoint: "https://api.github.com/user"
+          #  scopes: ["read:user"]
+          #  user_mapping_provider:
+          #    config:
+          #      subject_claim: "id"
+          #      localpart_template: "{{ user.login }}"
+          #      display_name_template: "{{ user.name }}"
+        """.format(
+            mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
+        )
+# jsonschema definition of the configuration settings for an oidc identity provider
+OIDC_PROVIDER_CONFIG_SCHEMA = {
+    "type": "object",
+    "required": ["issuer", "client_id", "client_secret"],
+    "properties": {
+        "idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
+        "idp_name": {"type": "string"},
+        "idp_icon": {"type": "string"},
+        "discover": {"type": "boolean"},
+        "issuer": {"type": "string"},
+        "client_id": {"type": "string"},
+        "client_secret": {"type": "string"},
+        "client_auth_method": {
+            "type": "string",
+            # the following list is the same as the keys of
+            # authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it
+            # to avoid importing authlib here.
+            "enum": ["client_secret_basic", "client_secret_post", "none"],
+        },
+        "scopes": {"type": "array", "items": {"type": "string"}},
+        "authorization_endpoint": {"type": "string"},
+        "token_endpoint": {"type": "string"},
+        "userinfo_endpoint": {"type": "string"},
+        "jwks_uri": {"type": "string"},
+        "skip_verification": {"type": "boolean"},
+        "user_profile_method": {
+            "type": "string",
+            "enum": ["auto", "userinfo_endpoint"],
+        },
+        "allow_existing_users": {"type": "boolean"},
+        "user_mapping_provider": {"type": ["object", "null"]},
+    },
+}
+
+# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
+OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
+    "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
+}
+
+
+# the `oidc_providers` list can either be None (as it is in the default config), or
+# a list of provider configs, each of which requires an explicit ID and name.
+OIDC_PROVIDER_LIST_SCHEMA = {
+    "oneOf": [
+        {"type": "null"},
+        {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
+    ]
+}
+
+# the `oidc_config` setting can either be None (which it used to be in the default
+# config), or an object. If an object, it is ignored unless it has an "enabled: True"
+# property.
+#
+# It's *possible* to represent this with jsonschema, but the resultant errors aren't
+# particularly clear, so we just check for either an object or a null here, and do
+# additional checks in the code.
+OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
+
+# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
+MAIN_CONFIG_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "oidc_config": OIDC_CONFIG_SCHEMA,
+        "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
+    },
+}
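A quick illustration of what these schemas accept. `validate_config` is the helper imported at the top of this file, and on failure it raises a `ConfigError` whose path follows the `<item N>` convention from `json_error_to_config_error` (the provider values below are placeholders):

```python
good = {
    "oidc_providers": [
        {
            "idp_id": "github",
            "idp_name": "GitHub",
            "issuer": "https://github.com/",
            "client_id": "abc",
            "client_secret": "xyz",
        }
    ]
}
validate_config(MAIN_CONFIG_SCHEMA, good, ())  # passes

# An entry missing idp_id/idp_name and the client details fails the array
# branch of the oneOf, so the whole list is rejected.
bad = {"oidc_providers": [{"issuer": "https://example.com/"}]}
try:
    validate_config(MAIN_CONFIG_SCHEMA, bad, ())
except Exception as e:
    print(e)  # error path points roughly at oidc_providers.<item 0>
```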
+def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
+    """extract and parse the OIDC provider configs from the config dict
+
+    The configuration may contain either a single `oidc_config` object with an
+    `enabled: True` property, or a list of provider configurations under
+    `oidc_providers`, *or both*.
+
+    Returns a generator which yields the OidcProviderConfig objects
+    """
+    validate_config(MAIN_CONFIG_SCHEMA, config, ())
+
+    for i, p in enumerate(config.get("oidc_providers") or []):
+        yield _parse_oidc_config_dict(p, ("oidc_providers", "<item %i>" % (i,)))
+
+    # for backwards-compatibility, it is also possible to provide a single "oidc_config"
+    # object with an "enabled: True" property.
+    oidc_config = config.get("oidc_config")
+    if oidc_config and oidc_config.get("enabled", False):
+        # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
+        # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
+        # above), so now we need to validate it.
+        validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
+        yield _parse_oidc_config_dict(oidc_config, ("oidc_config",))
+
+
+def _parse_oidc_config_dict(
+    oidc_config: JsonDict, config_path: Tuple[str, ...]
+) -> "OidcProviderConfig":
+    """Take the configuration dict and parse it into an OidcProviderConfig
+
+    Raises:
+        ConfigError if the configuration is malformed.
+    """
     ump_config = oidc_config.get("user_mapping_provider", {})
     ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
     ump_config.setdefault("config", {})

-    (
-        self.oidc_user_mapping_provider_class,
-        self.oidc_user_mapping_provider_config,
-    ) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
+    (user_mapping_provider_class, user_mapping_provider_config,) = load_module(
+        ump_config, config_path + ("user_mapping_provider",)
+    )

     # Ensure loaded user mapping module has defined all necessary methods
     required_methods = [
@@ -76,151 +322,124 @@ class OIDCConfig(Config):
     missing_methods = [
         method
         for method in required_methods
-        if not hasattr(self.oidc_user_mapping_provider_class, method)
+        if not hasattr(user_mapping_provider_class, method)
     ]
     if missing_methods:
         raise ConfigError(
-            "Class specified by oidc_config."
-            "user_mapping_provider.module is missing required "
-            "methods: %s" % (", ".join(missing_methods),)
+            "Class %s is missing required "
+            "methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),),
+            config_path + ("user_mapping_provider", "module"),
         )

-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
-        return """\
-        # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
-        #
-        # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
-        # for some example configurations.
-        #
-        oidc_config:
-          # Uncomment the following to enable authorization against an OpenID Connect
-          # server. Defaults to false.
-          #
-          #enabled: true
-
-          # Uncomment the following to disable use of the OIDC discovery mechanism to
-          # discover endpoints. Defaults to true.
-          #
-          #discover: false
+    # MSC2858 will apply certain limits in what can be used as an IdP id, so let's
+    # enforce those limits now.
+    # TODO: factor out this stuff to a generic function
+    idp_id = oidc_config.get("idp_id", "oidc")
+    valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._")
+
+    if any(c not in valid_idp_chars for c in idp_id):
+        raise ConfigError(
+            'idp_id may only contain a-z, 0-9, "-", ".", "_"',
+            config_path + ("idp_id",),
+        )
+
+    if idp_id[0] not in string.ascii_lowercase:
+        raise ConfigError(
+            "idp_id must start with a-z", config_path + ("idp_id",),
+        )
+
+    # MSC2858 also specifies that the idp_icon must be a valid MXC uri
+    idp_icon = oidc_config.get("idp_icon")
+    if idp_icon is not None:
+        try:
+            parse_and_validate_mxc_uri(idp_icon)
+        except ValueError as e:
+            raise ConfigError(
+                "idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
+            ) from e
+
+    return OidcProviderConfig(
+        idp_id=idp_id,
+        idp_name=oidc_config.get("idp_name", "OIDC"),
+        idp_icon=idp_icon,
+        discover=oidc_config.get("discover", True),
+        issuer=oidc_config["issuer"],
+        client_id=oidc_config["client_id"],
+        client_secret=oidc_config["client_secret"],
+        client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
+        scopes=oidc_config.get("scopes", ["openid"]),
+        authorization_endpoint=oidc_config.get("authorization_endpoint"),
+        token_endpoint=oidc_config.get("token_endpoint"),
+        userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
+        jwks_uri=oidc_config.get("jwks_uri"),
+        skip_verification=oidc_config.get("skip_verification", False),
+        user_profile_method=oidc_config.get("user_profile_method", "auto"),
+        allow_existing_users=oidc_config.get("allow_existing_users", False),
|
||||||
|
user_mapping_provider_class=user_mapping_provider_class,
|
||||||
|
user_mapping_provider_config=user_mapping_provider_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
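As a quick aside on the identifier rules above, a self-contained sketch (sample
IDs invented) of what the two checks accept and reject:

    import string

    valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._")

    def idp_id_ok(idp_id: str) -> bool:
        # mirrors the two ConfigError checks above
        return (
            bool(idp_id)
            and all(c in valid_idp_chars for c in idp_id)
            and idp_id[0] in string.ascii_lowercase
        )

    assert idp_id_ok("github")
    assert idp_id_ok("my-idp.v2")
    assert not idp_id_ok("9fast")   # must start with a-z
    assert not idp_id_ok("GitHub")  # upper case is rejected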
+@attr.s(slots=True, frozen=True)
+class OidcProviderConfig:
+    # a unique identifier for this identity provider. Used in the 'user_external_ids'
+    # table, as well as the query/path parameter used in the login protocol.
+    idp_id = attr.ib(type=str)
+
+    # user-facing name for this identity provider.
+    idp_name = attr.ib(type=str)
+
+    # Optional MXC URI for icon for this IdP.
+    idp_icon = attr.ib(type=Optional[str])
+
+    # whether the OIDC discovery mechanism is used to discover endpoints
+    discover = attr.ib(type=bool)
+
     # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
     # discover the provider's endpoints.
-        #
-        # Required if 'enabled' is true.
-        #
-        #issuer: "https://accounts.example.com/"
+    issuer = attr.ib(type=str)

-        # oauth2 client id to use.
-        #
-        # Required if 'enabled' is true.
-        #
-        #client_id: "provided-by-your-issuer"
+    # oauth2 client id to use
+    client_id = attr.ib(type=str)

-        # oauth2 client secret to use.
-        #
-        # Required if 'enabled' is true.
-        #
-        #client_secret: "provided-by-your-issuer"
+    # oauth2 client secret to use
+    client_secret = attr.ib(type=str)

     # auth method to use when exchanging the token.
-        # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
+    # Valid values are 'client_secret_basic', 'client_secret_post' and
     # 'none'.
-        #
-        #client_auth_method: client_secret_post
+    client_auth_method = attr.ib(type=str)

-        # list of scopes to request. This should normally include the "openid" scope.
-        # Defaults to ["openid"].
-        #
-        #scopes: ["openid", "profile"]
+    # list of scopes to request
+    scopes = attr.ib(type=Collection[str])

-        # the oauth2 authorization endpoint. Required if provider discovery is disabled.
-        #
-        #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+    # the oauth2 authorization endpoint. Required if discovery is disabled.
+    authorization_endpoint = attr.ib(type=Optional[str])

-        # the oauth2 token endpoint. Required if provider discovery is disabled.
-        #
-        #token_endpoint: "https://accounts.example.com/oauth2/token"
+    # the oauth2 token endpoint. Required if discovery is disabled.
+    token_endpoint = attr.ib(type=Optional[str])

     # the OIDC userinfo endpoint. Required if discovery is disabled and the
     # "openid" scope is not requested.
-        #
-        #userinfo_endpoint: "https://accounts.example.com/userinfo"
+    userinfo_endpoint = attr.ib(type=Optional[str])

     # URI where to fetch the JWKS. Required if discovery is disabled and the
     # "openid" scope is used.
-        #
-        #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+    jwks_uri = attr.ib(type=Optional[str])

-        # Uncomment to skip metadata verification. Defaults to false.
-        #
-        # Use this if you are connecting to a provider that is not OpenID Connect
-        # compliant.
-        # Avoid this in production.
-        #
-        #skip_verification: true
+    # Whether to skip metadata verification
+    skip_verification = attr.ib(type=bool)

     # Whether to fetch the user profile from the userinfo endpoint. Valid
     # values are: "auto" or "userinfo_endpoint".
-        #
-        # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
-        # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
-        #
-        #user_profile_method: "userinfo_endpoint"
+    user_profile_method = attr.ib(type=str)

-        # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
-        # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
-        #
-        #allow_existing_users: true
+    # whether to allow a user logging in via OIDC to match a pre-existing account
+    # instead of failing
+    allow_existing_users = attr.ib(type=bool)

-        # An external module can be provided here as a custom solution to mapping
-        # attributes returned from a OIDC provider onto a matrix user.
-        #
-        user_mapping_provider:
-          # The custom module's class. Uncomment to use a custom module.
-          # Default is {mapping_provider!r}.
-          #
-          # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
-          # for information on implementing a custom mapping provider.
-          #
-          #module: mapping_provider.OidcMappingProvider
+    # the class of the user mapping provider
+    user_mapping_provider_class = attr.ib(type=Type)

-          # Custom configuration values for the module. This section will be passed as
-          # a Python dictionary to the user mapping provider module's `parse_config`
-          # method.
-          #
-          # The examples below are intended for the default provider: they should be
-          # changed if using a custom provider.
-          #
-          config:
-            # name of the claim containing a unique identifier for the user.
-            # Defaults to `sub`, which OpenID Connect compliant providers should provide.
-            #
-            #subject_claim: "sub"
-
-            # Jinja2 template for the localpart of the MXID.
-            #
-            # When rendering, this template is given the following variables:
-            #   * user: The claims returned by the UserInfo Endpoint and/or in the ID
-            #     Token
-            #
-            # If this is not set, the user will be prompted to choose their
-            # own username.
-            #
-            #localpart_template: "{{{{ user.preferred_username }}}}"
-
-            # Jinja2 template for the display name to set on first login.
-            #
-            # If unset, no displayname will be set.
-            #
-            #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
-
-            # Jinja2 templates for extra attributes to send back to the client during
-            # login.
-            #
-            # Note that these are non-standard and clients will ignore them without modifications.
-            #
-            #extra_attributes:
-            #birthdate: "{{{{ user.birthdate }}}}"
-        """.format(
-            mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
-        )
+    # the config of the user mapping provider
+    user_mapping_provider_config = attr.ib()
@@ -14,14 +14,13 @@
 # limitations under the License.

 import os
-from distutils.util import strtobool

 import pkg_resources

 from synapse.api.constants import RoomCreationPreset
 from synapse.config._base import Config, ConfigError
 from synapse.types import RoomAlias, UserID
-from synapse.util.stringutils import random_string_with_symbols
+from synapse.util.stringutils import random_string_with_symbols, strtobool


 class AccountValidityConfig(Config):

@@ -50,10 +49,6 @@ class AccountValidityConfig(Config):

         self.startup_job_max_delta = self.period * 10.0 / 100.0

-        if self.renew_by_email_enabled:
-            if "public_baseurl" not in synapse_config:
-                raise ConfigError("Can't send renewal emails without 'public_baseurl'")
-
         template_dir = config.get("template_dir")

         if not template_dir:
@@ -86,12 +81,12 @@ class RegistrationConfig(Config):
     section = "registration"

     def read_config(self, config, **kwargs):
-        self.enable_registration = bool(
-            strtobool(str(config.get("enable_registration", False)))
+        self.enable_registration = strtobool(
+            str(config.get("enable_registration", False))
         )
         if "disable_registration" in config:
-            self.enable_registration = not bool(
-                strtobool(str(config["disable_registration"]))
+            self.enable_registration = not strtobool(
+                str(config["disable_registration"])
             )

         self.account_validity = AccountValidityConfig(
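The simplification relies on the relocated strtobool returning a truth value
directly; a hedged recap of the distutils-compatible contract it is assumed to
keep:

    assert strtobool("true") and strtobool("yes") and strtobool("1")
    assert not strtobool("false") and not strtobool("no") and not strtobool("0")
    # anything else, e.g. strtobool("maybe"), raises ValueError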
@@ -110,13 +105,6 @@ class RegistrationConfig(Config):
         account_threepid_delegates = config.get("account_threepid_delegates") or {}
         self.account_threepid_delegate_email = account_threepid_delegates.get("email")
         self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
-        if self.account_threepid_delegate_msisdn and not self.public_baseurl:
-            raise ConfigError(
-                "The configuration option `public_baseurl` is required if "
-                "`account_threepid_delegate.msisdn` is set, such that "
-                "clients know where to submit validation tokens to. Please "
-                "configure `public_baseurl`."
-            )

         self.default_identity_server = config.get("default_identity_server")
         self.allow_guest_access = config.get("allow_guest_access", False)
@@ -241,8 +229,9 @@ class RegistrationConfig(Config):
         # send an email to the account's email address with a renewal link. By
         # default, no such emails are sent.
         #
-        # If you enable this setting, you will also need to fill out the 'email' and
-        # 'public_baseurl' configuration sections.
+        # If you enable this setting, you will also need to fill out the 'email'
+        # configuration section. You should also check that 'public_baseurl' is set
+        # correctly.
         #
         #renew_at: 1w

@@ -333,8 +322,7 @@ class RegistrationConfig(Config):
         # The identity server which we suggest that clients should use when users log
         # in on this server.
         #
-        # (By default, no suggestion is made, so it is left up to the client.
-        # This setting is ignored unless public_baseurl is also set.)
+        # (By default, no suggestion is made, so it is left up to the client.)
         #
         #default_identity_server: https://matrix.org

@@ -359,8 +347,6 @@ class RegistrationConfig(Config):
         # by the Matrix Identity Service API specification:
         # https://matrix.org/docs/spec/identity_service/latest
         #
-        # If a delegate is specified, the config option public_baseurl must also be filled out.
-        #
         account_threepid_delegates:
             #email: https://example.com     # Delegate email sending to example.com
             #msisdn: http://localhost:8090  # Delegate SMS sending to this local process
@@ -189,8 +189,6 @@ class SAML2Config(Config):
         import saml2

         public_baseurl = self.public_baseurl
-        if public_baseurl is None:
-            raise ConfigError("saml2_config requires a public_baseurl to be set")

         if self.saml2_grandfathered_mxid_source_attribute:
             optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
@@ -26,7 +26,7 @@ import yaml
 from netaddr import IPSet

 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.http.endpoint import parse_and_validate_server_name
+from synapse.util.stringutils import parse_and_validate_server_name

 from ._base import Config, ConfigError
@@ -161,7 +161,11 @@ class ServerConfig(Config):
         self.print_pidfile = config.get("print_pidfile")
         self.user_agent_suffix = config.get("user_agent_suffix")
         self.use_frozen_dicts = config.get("use_frozen_dicts", False)
-        self.public_baseurl = config.get("public_baseurl")
+        self.public_baseurl = config.get("public_baseurl") or "https://%s/" % (
+            self.server_name,
+        )
+        if self.public_baseurl[-1] != "/":
+            self.public_baseurl += "/"

         # Whether to enable user presence.
         self.use_presence = config.get("use_presence", True)
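To make the new default concrete, a small sketch (server name invented): with
server_name "example.com" and no explicit public_baseurl, the computed value is
"https://example.com/", which is why a listener or reverse proxy must then be
reachable on port 443.

    server_name = "example.com"
    public_baseurl = None or "https://%s/" % (server_name,)
    if public_baseurl[-1] != "/":
        public_baseurl += "/"
    assert public_baseurl == "https://example.com/"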
@@ -317,9 +321,6 @@ class ServerConfig(Config):
         # Always blacklist 0.0.0.0, ::
         self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])

-        if self.public_baseurl is not None:
-            if self.public_baseurl[-1] != "/":
-                self.public_baseurl += "/"
         self.start_pushers = config.get("start_pushers", True)

         # (undocumented) option for torturing the worker-mode replication a bit,
@@ -740,11 +741,16 @@ class ServerConfig(Config):
         #
         #web_client_location: https://riot.example.com/

-        # The public-facing base URL that clients use to access this HS
-        # (not including _matrix/...). This is the same URL a user would
-        # enter into the 'custom HS URL' field on their client. If you
-        # use synapse with a reverse proxy, this should be the URL to reach
-        # synapse via the proxy.
+        # The public-facing base URL that clients use to access this Homeserver (not
+        # including _matrix/...). This is the same URL a user might enter into the
+        # 'Custom Homeserver URL' field on their client. If you use Synapse with a
+        # reverse proxy, this should be the URL to reach Synapse via the proxy.
+        # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
+        # 'listeners' below).
+        #
+        # If this is left unset, it defaults to 'https://<server_name>/'. (Note that
+        # that will not work unless you configure Synapse or a reverse-proxy to listen
+        # on port 443.)
         #
         #public_baseurl: https://example.com/
@@ -31,18 +31,22 @@ class SSOConfig(Config):

         # Read templates from disk
         (
+            self.sso_login_idp_picker_template,
             self.sso_redirect_confirm_template,
             self.sso_auth_confirm_template,
             self.sso_error_template,
             sso_account_deactivated_template,
             sso_auth_success_template,
+            self.sso_auth_bad_user_template,
         ) = self.read_templates(
             [
+                "sso_login_idp_picker.html",
                 "sso_redirect_confirm.html",
                 "sso_auth_confirm.html",
                 "sso_error.html",
                 "sso_account_deactivated.html",
                 "sso_auth_success.html",
+                "sso_auth_bad_user.html",
             ],
             template_dir,
         )
@@ -60,9 +64,6 @@ class SSOConfig(Config):
         # gracefully to the client). This would make it pointless to ask the user for
         # confirmation, since the URL the confirmation page would be showing wouldn't be
         # the client's.
-        # public_baseurl is an optional setting, so we only add the fallback's URL to the
-        # list if it's provided (because we can't figure out what that URL is otherwise).
-        if self.public_baseurl:
-            login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
-            self.sso_client_whitelist.append(login_fallback_url)
+        login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
+        self.sso_client_whitelist.append(login_fallback_url)
@@ -82,9 +83,9 @@ class SSOConfig(Config):
         # phishing attacks from evil.site. To avoid this, include a slash after the
         # hostname: "https://my.client/".
         #
-        # If public_baseurl is set, then the login fallback page (used by clients
-        # that don't natively support the required login flows) is whitelisted in
-        # addition to any URLs in this list.
+        # The login fallback page (used by clients that don't natively support the
+        # required login flows) is automatically whitelisted in addition to any URLs
+        # in this list.
         #
         # By default, this list is empty.
         #
@@ -98,6 +99,31 @@ class SSOConfig(Config):
         #
         # Synapse will look for the following templates in this directory:
         #
+        # * HTML page to prompt the user to choose an Identity Provider during
+        #   login: 'sso_login_idp_picker.html'.
+        #
+        #   This is only used if multiple SSO Identity Providers are configured.
+        #
+        #   When rendering, this template is given the following variables:
+        #     * redirect_url: the URL that the user will be redirected to after
+        #       login. Needs manual escaping (see
+        #       https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+        #
+        #     * server_name: the homeserver's name.
+        #
+        #     * providers: a list of available Identity Providers. Each element is
+        #       an object with the following attributes:
+        #         * idp_id: unique identifier for the IdP
+        #         * idp_name: user-facing name for the IdP
+        #
+        #   The rendered HTML page should contain a form which submits its results
+        #   back as a GET request, with the following query parameters:
+        #
+        #     * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
+        #       to the template)
+        #
+        #     * idp: the 'idp_id' of the chosen IDP.
+        #
         # * HTML page for a confirmation step before redirecting back to the client
         #   with the login token: 'sso_redirect_confirm.html'.
         #
@@ -133,6 +159,14 @@ class SSOConfig(Config):
         #
         # This template has no additional variables.
         #
+        # * HTML page shown after a user-interactive authentication session which
+        #   does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
+        #
+        #   When rendering, this template is given the following variables:
+        #     * server_name: the homeserver's name.
+        #     * user_id_to_verify: the MXID of the user that we are trying to
+        #       validate.
+        #
         # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
         #   attempts to login: 'sso_account_deactivated.html'.
         #
@@ -53,6 +53,15 @@ class WriterLocations:
         default=["master"], type=List[str], converter=_instance_to_list_converter
     )
     typing = attr.ib(default="master", type=str)
+    to_device = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
+    account_data = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
+    receipts = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )


 class WorkerConfig(Config):
@@ -124,7 +133,7 @@ class WorkerConfig(Config):

         # Check that the configured writers for events and typing also appears in
         # `instance_map`.
-        for stream in ("events", "typing"):
+        for stream in ("events", "typing", "to_device", "account_data", "receipts"):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
                 if instance != "master" and instance not in self.instance_map:
@@ -133,6 +142,21 @@ class WorkerConfig(Config):
                     % (instance, stream)
                 )

+        if len(self.writers.to_device) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `to_device` messages."
+            )
+
+        if len(self.writers.account_data) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `account_data` messages."
+            )
+
+        if len(self.writers.receipts) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `receipts` messages."
+            )
+
         self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)

         # Whether this worker should run background tasks or not.
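A hedged sketch of what the new attributes and checks permit (worker names
invented): `events` may still be sharded across several writers, while each of
the three new streams must name exactly one instance.

    writers = WriterLocations(
        events=["event_writer1", "event_writer2"],  # sharding still allowed
        typing="worker1",
        to_device=["worker2"],     # exactly one instance
        account_data=["worker2"],  # exactly one instance
        receipts=["worker3"],      # exactly one instance
    )
    assert len(writers.account_data) == 1  # anything else is a ConfigError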
@@ -17,7 +17,6 @@

 import abc
 import os
-from distutils.util import strtobool
 from typing import Dict, Optional, Tuple, Type

 from unpaddedbase64 import encode_base64

@@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
+from synapse.util.stringutils import strtobool

 # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
 # bugs where we accidentally share e.g. signature dicts. However, converting a

@@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze
 # NOTE: This is overridden by the configuration by the Synapse worker apps, but
 # for the sake of tests, it is set here while it cannot be configured on the
 # homeserver object itself.
+
 USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
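For illustration, the environment toggle above keeps working unchanged because
the relocated helper is assumed to preserve the distutils truth-value
spellings:

    import os

    os.environ["SYNAPSE_USE_FROZEN_DICTS"] = "yes"  # "1", "true", "on" also work
    assert strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))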
@@ -79,13 +79,15 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
         "state_key",
         "depth",
         "prev_events",
-        "prev_state",
         "auth_events",
         "origin",
         "origin_server_ts",
-        "membership",
     ]

+    # Room versions from before MSC2176 had additional allowed keys.
+    if not room_version.msc2176_redaction_rules:
+        allowed_keys.extend(["prev_state", "membership"])
+
     event_type = event_dict["type"]

     new_content = {}

@@ -98,6 +100,10 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
     if event_type == EventTypes.Member:
         add_fields("membership")
     elif event_type == EventTypes.Create:
+        # MSC2176 rules state that create events cannot be redacted.
+        if room_version.msc2176_redaction_rules:
+            return event_dict
+
         add_fields("creator")
     elif event_type == EventTypes.JoinRules:
         add_fields("join_rule")

@@ -112,10 +118,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
             "kick",
             "redact",
         )
+
+        if room_version.msc2176_redaction_rules:
+            add_fields("invite")
+
     elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
         add_fields("aliases")
     elif event_type == EventTypes.RoomHistoryVisibility:
         add_fields("history_visibility")
+    elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
+        add_fields("redacts")

     allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
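A hedged illustration of the MSC2176 behaviour above (the RoomVersions
attribute name is assumed from the changelog's "experimental room version"; the
event is invented): create events survive redaction intact, whereas older room
versions keep only the "creator" content key.

    from synapse.api.room_versions import RoomVersions
    from synapse.events.utils import prune_event_dict

    create_event = {
        "type": "m.room.create",
        "room_id": "!room:example.com",
        "sender": "@alice:example.com",
        "content": {"creator": "@alice:example.com", "custom": "dropped?"},
        "signatures": {},
        "unsigned": {},
    }

    assert prune_event_dict(RoomVersions.MSC2176, dict(create_event)) == create_event
    pruned = prune_event_dict(RoomVersions.V6, dict(create_event))
    assert pruned["content"] == {"creator": "@alice:example.com"}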
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import (
     TYPE_CHECKING,
     Any,

@@ -48,7 +49,6 @@ from synapse.events import EventBase
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
-from synapse.http.endpoint import parse_server_name
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
     make_deferred_yieldable,

@@ -65,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import parse_server_name

 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -860,8 +861,10 @@ class FederationHandlerRegistry:
         )  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
         self.query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[None]]]

-        # Map from type to instance name that we should route EDU handling to.
-        self._edu_type_to_instance = {}  # type: Dict[str, str]
+        # Map from type to instance names that we should route EDU handling to.
+        # We randomly choose one instance from the list to route to for each new
+        # EDU received.
+        self._edu_type_to_instance = {}  # type: Dict[str, List[str]]

     def register_edu_handler(
         self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
|
@ -905,7 +908,12 @@ class FederationHandlerRegistry:
|
||||||
def register_instance_for_edu(self, edu_type: str, instance_name: str):
|
def register_instance_for_edu(self, edu_type: str, instance_name: str):
|
||||||
"""Register that the EDU handler is on a different instance than master.
|
"""Register that the EDU handler is on a different instance than master.
|
||||||
"""
|
"""
|
||||||
self._edu_type_to_instance[edu_type] = instance_name
|
self._edu_type_to_instance[edu_type] = [instance_name]
|
||||||
|
|
||||||
|
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
|
||||||
|
"""Register that the EDU handler is on multiple instances.
|
||||||
|
"""
|
||||||
|
self._edu_type_to_instance[edu_type] = instance_names
|
||||||
|
|
||||||
async def on_edu(self, edu_type: str, origin: str, content: dict):
|
async def on_edu(self, edu_type: str, origin: str, content: dict):
|
||||||
if not self.config.use_presence and edu_type == "m.presence":
|
if not self.config.use_presence and edu_type == "m.presence":
|
||||||
|
@ -924,8 +932,11 @@ class FederationHandlerRegistry:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if we can route it somewhere else that isn't us
|
# Check if we can route it somewhere else that isn't us
|
||||||
route_to = self._edu_type_to_instance.get(edu_type, "master")
|
instances = self._edu_type_to_instance.get(edu_type, ["master"])
|
||||||
if route_to != self._instance_name:
|
if self._instance_name not in instances:
|
||||||
|
# Pick an instance randomly so that we don't overload one.
|
||||||
|
route_to = random.choice(instances)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._send_edu(
|
await self._send_edu(
|
||||||
instance_name=route_to,
|
instance_name=route_to,
|
||||||
|
|
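A standalone sketch of the routing rule (EDU type and worker names invented):
once several instances are registered for a type, each incoming EDU is sent to
one of them at random instead of always the same worker.

    import random

    edu_type_to_instance = {"m.typing": ["worker1", "worker2"]}
    instance_name = "master"

    instances = edu_type_to_instance.get("m.typing", ["master"])
    if instance_name not in instances:
        route_to = random.choice(instances)  # spreads load across the writers
        assert route_to in ("worker1", "worker2")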
|
@ -28,7 +28,6 @@ from synapse.api.urls import (
|
||||||
FEDERATION_V1_PREFIX,
|
FEDERATION_V1_PREFIX,
|
||||||
FEDERATION_V2_PREFIX,
|
FEDERATION_V2_PREFIX,
|
||||||
)
|
)
|
||||||
from synapse.http.endpoint import parse_and_validate_server_name
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
parse_boolean_from_args,
|
parse_boolean_from_args,
|
||||||
|
@ -45,6 +44,7 @@ from synapse.logging.opentracing import (
|
||||||
)
|
)
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
|
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
|
||||||
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@@ -12,14 +13,157 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 from typing import TYPE_CHECKING, List, Tuple

+from synapse.replication.http.account_data import (
+    ReplicationAddTagRestServlet,
+    ReplicationRemoveTagRestServlet,
+    ReplicationRoomAccountDataRestServlet,
+    ReplicationUserAccountDataRestServlet,
+)
 from synapse.types import JsonDict, UserID

 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer


+class AccountDataHandler:
+    def __init__(self, hs: "HomeServer"):
+        self._store = hs.get_datastore()
+        self._instance_name = hs.get_instance_name()
+        self._notifier = hs.get_notifier()
+
+        self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
+        self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
+        self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
+        self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
+        self._account_data_writers = hs.config.worker.writers.account_data
+
+    async def add_account_data_to_room(
+        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some account_data to a room for a user.
+
+        Args:
+            user_id: The user to add account_data for.
+            room_id: The room to add the account_data to.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the account_data.
+
+        Returns:
+            The maximum stream ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_to_room(
+                user_id, room_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+
+            return max_stream_id
+        else:
+            response = await self._room_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_account_data_for_user(
+        self, user_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some global account_data for a user.
+
+        Args:
+            user_id: The user to add account_data for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the account_data.
+
+        Returns:
+            The maximum stream ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_for_user(
+                user_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._user_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_tag_to_room(
+        self, user_id: str, room_id: str, tag: str, content: JsonDict
+    ) -> int:
+        """Add a tag to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            tag: The tag name to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_tag_to_room(
+                user_id, room_id, tag, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._add_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
+        """Remove a tag from a room for a user.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.remove_tag_from_room(
+                user_id, room_id, tag
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._remove_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+            )
+            return response["max_stream_id"]
+
+
 class AccountDataEventSource:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
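A hedged usage sketch of the new handler (the HomeServer accessor name is
assumed): callers do not need to know whether they are on a writer, since the
handler either writes locally and pokes the notifier, or proxies to a randomly
chosen account_data writer over replication.

    async def favourite_room(hs, user_id: str, room_id: str) -> int:
        handler = hs.get_account_data_handler()  # accessor name assumed
        return await handler.add_tag_to_room(
            user_id, room_id, "m.favourite", {"order": 0.5}
        )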
|
@ -49,8 +49,13 @@ from synapse.api.errors import (
|
||||||
UserDeactivatedError,
|
UserDeactivatedError,
|
||||||
)
|
)
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
|
from synapse.handlers._base import BaseHandler
|
||||||
|
from synapse.handlers.ui_auth import (
|
||||||
|
INTERACTIVE_AUTH_CHECKERS,
|
||||||
|
UIAuthSessionDataConstants,
|
||||||
|
)
|
||||||
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
||||||
|
from synapse.http import get_request_user_agent
|
||||||
from synapse.http.server import finish_request, respond_with_html
|
from synapse.http.server import finish_request, respond_with_html
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import defer_to_thread
|
from synapse.logging.context import defer_to_thread
|
||||||
|
@ -62,8 +67,6 @@ from synapse.util.async_helpers import maybe_awaitable
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
from synapse.util.threepids import canonicalise_email
|
from synapse.util.threepids import canonicalise_email
|
||||||
|
|
||||||
from ._base import BaseHandler
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.app.homeserver import HomeServer
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
|
@ -260,10 +263,6 @@ class AuthHandler(BaseHandler):
|
||||||
# authenticating for an operation to occur on their account.
|
# authenticating for an operation to occur on their account.
|
||||||
self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
|
self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
|
||||||
|
|
||||||
# The following template is shown after a successful user interactive
|
|
||||||
# authentication session. It tells the user they can close the window.
|
|
||||||
self._sso_auth_success_template = hs.config.sso_auth_success_template
|
|
||||||
|
|
||||||
# The following template is shown during the SSO authentication process if
|
# The following template is shown during the SSO authentication process if
|
||||||
# the account is deactivated.
|
# the account is deactivated.
|
||||||
self._sso_account_deactivated_template = (
|
self._sso_account_deactivated_template = (
|
||||||
|
@ -284,7 +283,6 @@ class AuthHandler(BaseHandler):
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
request_body: Dict[str, Any],
|
request_body: Dict[str, Any],
|
||||||
clientip: str,
|
|
||||||
description: str,
|
description: str,
|
||||||
) -> Tuple[dict, Optional[str]]:
|
) -> Tuple[dict, Optional[str]]:
|
||||||
"""
|
"""
|
||||||
|
@ -301,8 +299,6 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
request_body: The body of the request sent by the client
|
request_body: The body of the request sent by the client
|
||||||
|
|
||||||
clientip: The IP address of the client.
|
|
||||||
|
|
||||||
description: A human readable string to be displayed to the user that
|
description: A human readable string to be displayed to the user that
|
||||||
describes the operation happening on their account.
|
describes the operation happening on their account.
|
||||||
|
|
||||||
|
@ -338,10 +334,10 @@ class AuthHandler(BaseHandler):
|
||||||
request_body.pop("auth", None)
|
request_body.pop("auth", None)
|
||||||
return request_body, None
|
return request_body, None
|
||||||
|
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
# Check if we should be ratelimited due to too many previous failed attempts
|
# Check if we should be ratelimited due to too many previous failed attempts
|
||||||
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
|
self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
|
||||||
|
|
||||||
# build a list of supported flows
|
# build a list of supported flows
|
||||||
supported_ui_auth_types = await self._get_available_ui_auth_types(
|
supported_ui_auth_types = await self._get_available_ui_auth_types(
|
||||||
|
@ -349,13 +345,16 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
flows = [[login_type] for login_type in supported_ui_auth_types]
|
flows = [[login_type] for login_type in supported_ui_auth_types]
|
||||||
|
|
||||||
|
def get_new_session_data() -> JsonDict:
|
||||||
|
return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result, params, session_id = await self.check_ui_auth(
|
result, params, session_id = await self.check_ui_auth(
|
||||||
flows, request, request_body, clientip, description
|
flows, request, request_body, description, get_new_session_data,
|
||||||
)
|
)
|
||||||
except LoginError:
|
except LoginError:
|
||||||
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
|
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
|
||||||
self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
|
self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# find the completed login type
|
# find the completed login type
|
||||||
|
@ -363,14 +362,14 @@ class AuthHandler(BaseHandler):
|
||||||
if login_type not in result:
|
if login_type not in result:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
user_id = result[login_type]
|
validated_user_id = result[login_type]
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# this can't happen
|
# this can't happen
|
||||||
raise Exception("check_auth returned True but no successful login type")
|
raise Exception("check_auth returned True but no successful login type")
|
||||||
|
|
||||||
# check that the UI auth matched the access token
|
# check that the UI auth matched the access token
|
||||||
if user_id != requester.user.to_string():
|
if validated_user_id != requester_user_id:
|
||||||
raise AuthError(403, "Invalid auth")
|
raise AuthError(403, "Invalid auth")
|
||||||
|
|
||||||
# Note that the access token has been validated.
|
# Note that the access token has been validated.
|
||||||
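The callback threads the requester through to the SSO step; a short sketch of
the round-trip (paraphrasing the code above and the start_sso_ui_auth change
below): the user ID is stored when the session is created and read back when
the SSO confirmation page is built.

    # at session creation (inside check_ui_auth), the seeded data is:
    #   {UIAuthSessionDataConstants.REQUEST_USER_ID: "@alice:example.com"}
    # later, during the SSO leg:
    user_id_to_verify = await self.get_session_data(
        session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
    )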
|
@ -402,13 +401,9 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
|
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
|
||||||
# from sso to mxid.
|
# from sso to mxid.
|
||||||
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
|
if await self.hs.get_sso_handler().get_identity_providers_for_user(
|
||||||
if await self.store.get_external_ids_by_user(user.to_string()):
|
user.to_string()
|
||||||
ui_auth_types.add(LoginType.SSO)
|
):
|
||||||
|
|
||||||
# Our CAS impl does not (yet) correctly register users in user_external_ids,
|
|
||||||
# so always offer that if it's available.
|
|
||||||
if self.hs.config.cas.cas_enabled:
|
|
||||||
ui_auth_types.add(LoginType.SSO)
|
ui_auth_types.add(LoginType.SSO)
|
||||||
|
|
||||||
return ui_auth_types
|
return ui_auth_types
|
||||||
|
@ -426,8 +421,8 @@ class AuthHandler(BaseHandler):
|
||||||
flows: List[List[str]],
|
flows: List[List[str]],
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
clientdict: Dict[str, Any],
|
clientdict: Dict[str, Any],
|
||||||
clientip: str,
|
|
||||||
description: str,
|
description: str,
|
||||||
|
get_new_session_data: Optional[Callable[[], JsonDict]] = None,
|
||||||
) -> Tuple[dict, dict, str]:
|
) -> Tuple[dict, dict, str]:
|
||||||
"""
|
"""
|
||||||
Takes a dictionary sent by the client in the login / registration
|
Takes a dictionary sent by the client in the login / registration
|
||||||
|
@ -448,11 +443,16 @@ class AuthHandler(BaseHandler):
|
||||||
clientdict: The dictionary from the client root level, not the
|
clientdict: The dictionary from the client root level, not the
|
||||||
'auth' key: this method prompts for auth if none is sent.
|
'auth' key: this method prompts for auth if none is sent.
|
||||||
|
|
||||||
clientip: The IP address of the client.
|
|
||||||
|
|
||||||
description: A human readable string to be displayed to the user that
|
description: A human readable string to be displayed to the user that
|
||||||
describes the operation happening on their account.
|
describes the operation happening on their account.
|
||||||
|
|
||||||
|
get_new_session_data:
|
||||||
|
an optional callback which will be called when starting a new session.
|
||||||
|
it should return data to be stored as part of the session.
|
||||||
|
|
||||||
|
The keys of the returned data should be entries in
|
||||||
|
UIAuthSessionDataConstants.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (creds, params, session_id).
|
A tuple of (creds, params, session_id).
|
||||||
|
|
||||||
|
@ -480,10 +480,15 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# If there's no session ID, create a new session.
|
# If there's no session ID, create a new session.
|
||||||
if not sid:
|
if not sid:
|
||||||
|
new_session_data = get_new_session_data() if get_new_session_data else {}
|
||||||
|
|
||||||
session = await self.store.create_ui_auth_session(
|
session = await self.store.create_ui_auth_session(
|
||||||
clientdict, uri, method, description
|
clientdict, uri, method, description
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for k, v in new_session_data.items():
|
||||||
|
await self.set_session_data(session.session_id, k, v)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
session = await self.store.get_ui_auth_session(sid)
|
session = await self.store.get_ui_auth_session(sid)
|
||||||
|
@ -539,7 +544,8 @@ class AuthHandler(BaseHandler):
|
||||||
# authentication flow.
|
# authentication flow.
|
||||||
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
||||||
|
|
||||||
user_agent = request.get_user_agent("")
|
user_agent = get_request_user_agent(request)
|
||||||
|
clientip = request.getClientIP()
|
||||||
|
|
||||||
await self.store.add_user_agent_ip_to_ui_auth_session(
|
await self.store.add_user_agent_ip_to_ui_auth_session(
|
||||||
session.session_id, user_agent, clientip
|
session.session_id, user_agent, clientip
|
||||||
|
@ -644,7 +650,8 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The ID of this session as returned from check_auth
|
session_id: The ID of this session as returned from check_auth
|
||||||
key: The key to store the data under
|
key: The key to store the data under. An entry from
|
||||||
|
UIAuthSessionDataConstants.
|
||||||
value: The data to store
|
value: The data to store
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
@ -660,7 +667,8 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The ID of this session as returned from check_auth
|
session_id: The ID of this session as returned from check_auth
|
||||||
key: The key to store the data under
|
key: The key the data was stored under. An entry from
|
||||||
|
UIAuthSessionDataConstants.
|
||||||
default: Value to return if the key has not been set
|
default: Value to return if the key has not been set
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
@ -1334,12 +1342,12 @@ class AuthHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
|
async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Get the HTML for the SSO redirect confirmation page.
|
Get the HTML for the SSO redirect confirmation page.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
redirect_url: The URL to redirect to the SSO provider.
|
request: The incoming HTTP request
|
||||||
session_id: The user interactive authentication session ID.
|
session_id: The user interactive authentication session ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -1349,31 +1357,39 @@ class AuthHandler(BaseHandler):
|
||||||
session = await self.store.get_ui_auth_session(session_id)
|
session = await self.store.get_ui_auth_session(session_id)
|
||||||
except StoreError:
|
except StoreError:
|
||||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||||
|
|
||||||
|
user_id_to_verify = await self.get_session_data(
|
||||||
|
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
|
||||||
|
) # type: str
|
||||||
|
|
||||||
|
idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
|
||||||
|
user_id_to_verify
|
||||||
|
)
|
||||||
|
|
||||||
|
if not idps:
|
||||||
|
# we checked that the user had some remote identities before offering an SSO
|
||||||
|
# flow, so either it's been deleted or the client has requested SSO despite
|
||||||
|
# it not being offered.
|
||||||
|
raise SynapseError(400, "User has no SSO identities")
|
||||||
|
|
||||||
|
# for now, just pick one
|
||||||
|
idp_id, sso_auth_provider = next(iter(idps.items()))
|
||||||
|
if len(idps) > 0:
|
||||||
|
logger.warning(
|
||||||
|
"User %r has previously logged in with multiple SSO IdPs; arbitrarily "
|
||||||
|
"picking %r",
|
||||||
|
user_id_to_verify,
|
||||||
|
idp_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
redirect_url = await sso_auth_provider.handle_redirect_request(
|
||||||
|
request, None, session_id
|
||||||
|
)
|
||||||
|
|
||||||
return self._sso_auth_confirm_template.render(
|
return self._sso_auth_confirm_template.render(
|
||||||
description=session.description, redirect_url=redirect_url,
|
description=session.description, redirect_url=redirect_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def complete_sso_ui_auth(
|
|
||||||
self, registered_user_id: str, session_id: str, request: Request,
|
|
||||||
):
|
|
||||||
"""Having figured out a mxid for this user, complete the HTTP request
|
|
||||||
|
|
||||||
Args:
|
|
||||||
registered_user_id: The registered user ID to complete SSO login for.
|
|
||||||
session_id: The ID of the user-interactive auth session.
|
|
||||||
request: The request to complete.
|
|
||||||
"""
|
|
||||||
# Mark the stage of the authentication as successful.
|
|
||||||
# Save the user who authenticated with SSO, this will be used to ensure
|
|
||||||
# that the account be modified is also the person who logged in.
|
|
||||||
await self.store.mark_ui_auth_stage_complete(
|
|
||||||
session_id, LoginType.SSO, registered_user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Render the HTML and return.
|
|
||||||
html = self._sso_auth_success_template
|
|
||||||
respond_with_html(request, 200, html)
|
|
||||||
|
|
||||||
async def complete_sso_login(
|
async def complete_sso_login(
|
||||||
self,
|
self,
|
||||||
registered_user_id: str,
|
registered_user_id: str,
|
||||||
|
@ -1488,8 +1504,8 @@ class AuthHandler(BaseHandler):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_query_param_to_url(url: str, param_name: str, param: Any):
|
def add_query_param_to_url(url: str, param_name: str, param: Any):
|
||||||
url_parts = list(urllib.parse.urlparse(url))
|
url_parts = list(urllib.parse.urlparse(url))
|
||||||
query = dict(urllib.parse.parse_qsl(url_parts[4]))
|
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
|
||||||
query.update({param_name: param})
|
query.append((param_name, param))
|
||||||
url_parts[4] = urllib.parse.urlencode(query)
|
url_parts[4] = urllib.parse.urlencode(query)
|
||||||
return urllib.parse.urlunparse(url_parts)
|
return urllib.parse.urlunparse(url_parts)
|
||||||
|
|
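A side note on the `add_query_param_to_url` hunk just above: `dict(parse_qsl(...))` collapses duplicate query keys and, by default, drops parameters whose value is blank, so round-tripping a URL through it could silently lose parameters. Keeping the parsed list of pairs and passing `keep_blank_values=True` preserves both. A minimal stdlib-only sketch (query values are illustrative):

    import urllib.parse

    query = "foo=&bar=1&bar=2"
    # old behaviour: blank 'foo' is dropped and the duplicate 'bar' collapses
    dict(urllib.parse.parse_qsl(query))
    # -> {'bar': '2'}
    # new behaviour: everything survives, in order
    urllib.parse.parse_qsl(query, keep_blank_values=True)
    # -> [('foo', ''), ('bar', '1'), ('bar', '2')]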
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
@@ -75,10 +75,19 @@ class CasHandler:
         self._http_client = hs.get_proxied_http_client()
 
         # identifier for the external_ids table
-        self._auth_provider_id = "cas"
+        self.idp_id = "cas"
+
+        # user-facing name of this auth provider
+        self.idp_name = "CAS"
+
+        # we do not currently support icons for CAS auth, but this is required by
+        # the SsoIdentityProvider protocol type.
+        self.idp_icon = None
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _build_service_param(self, args: Dict[str, str]) -> str:
         """
         Generates a value to use as the "service" parameter when redirecting or
@@ -105,7 +114,7 @@ class CasHandler:
         Args:
             ticket: The CAS ticket from the client.
             service_args: Additional arguments to include in the service URL.
-                Should be the same as those passed to `get_redirect_url`.
+                Should be the same as those passed to `handle_redirect_request`.
 
         Raises:
             CasError: If there's an error parsing the CAS response.
@@ -184,16 +193,31 @@ class CasHandler:
 
         return CasResponse(user, attributes)
 
-    def get_redirect_url(self, service_args: Dict[str, str]) -> str:
-        """
-        Generates a URL for the CAS server where the client should be redirected.
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Generates a URL for the CAS server where the client should be redirected.
 
         Args:
-            service_args: Additional arguments to include in the final redirect URL.
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
 
         Returns:
-            The URL to redirect the client to.
+            URL to redirect to
         """
 
+        if ui_auth_session_id:
+            service_args = {"session": ui_auth_session_id}
+        else:
+            assert client_redirect_url
+            service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
+
         args = urllib.parse.urlencode(
             {"service": self._build_service_param(service_args)}
         )
@@ -275,7 +299,7 @@ class CasHandler:
         # first check if we're doing a UIA
         if session:
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id, cas_response.username, session, request,
+                self.idp_id, cas_response.username, session, request,
             )
 
         # otherwise, we're handling a login request.
@@ -375,7 +399,7 @@ class CasHandler:
             return None
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             cas_response.username,
             request,
             client_redirect_url,
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
 
 from ._base import BaseHandler
 
@@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler):
         self._device_handler = hs.get_device_handler()
         self._room_member_handler = hs.get_room_member_handler()
         self._identity_handler = hs.get_identity_handler()
+        self._profile_handler = hs.get_profile_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
         self._server_name = hs.hostname
 
@@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler):
         self._account_validity_enabled = hs.config.account_validity.enabled
 
     async def deactivate_account(
-        self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+        self,
+        user_id: str,
+        erase_data: bool,
+        requester: Requester,
+        id_server: Optional[str] = None,
+        by_admin: bool = False,
     ) -> bool:
         """Deactivate a user's account
 
         Args:
             user_id: ID of user to be deactivated
             erase_data: whether to GDPR-erase the user's data
+            requester: The user attempting to make this change.
             id_server: Use the given identity server when unbinding
                 any threepids. If None then will attempt to unbind using the
                 identity server specified when binding (if known).
+            by_admin: Whether this change was made by an administrator.
 
         Returns:
             True if identity server supports removing threepids, otherwise False.
@@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler):
 
         # Mark the user as erased, if they asked for that
         if erase_data:
+            user = UserID.from_string(user_id)
+            # Remove avatar URL from this user
+            await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
+            # Remove displayname from this user
+            await self._profile_handler.set_displayname(user, requester, "", by_admin)
+
             logger.info("Marking %s as erased", user_id)
             await self.store.mark_user_erased(user_id)
 
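`deactivate_account` now takes the acting `Requester` and an explicit `by_admin` flag so that the profile wipe above is attributed to the right actor. A hypothetical caller sketch (the admin user ID is made up; the handler accessor is assumed, not taken from this diff):

    from synapse.types import create_requester

    # e.g. from an admin API servlet:
    await hs.get_deactivate_account_handler().deactivate_account(
        user_id="@alice:example.com",
        erase_data=True,
        requester=create_requester("@admin:example.com"),
        by_admin=True,
    )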
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
@@ -24,6 +24,7 @@ from synapse.logging.opentracing import (
     set_tag,
     start_active_span,
 )
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.stringutils import random_string
@@ -44,13 +45,37 @@ class DeviceMessageHandler:
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
         self.is_mine = hs.is_mine
-        self.federation = hs.get_federation_sender()
 
+        # We only need to poke the federation sender explicitly if its on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the to-device replication stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the to device EDUs we do so, otherwise we route them
+        # to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.to_device:
             hs.get_federation_registry().register_edu_handler(
                 "m.direct_to_device", self.on_direct_to_device_edu
             )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.direct_to_device", hs.config.worker.writers.to_device,
+            )
 
-        self._device_list_updater = hs.get_device_handler().device_list_updater
+        # The handler to call when we think a user's device list might be out of
+        # sync. We do all device list resyncing on the master instance, so if
+        # we're on a worker we hit the device resync replication API.
+        if hs.config.worker.worker_app is None:
+            self._user_device_resync = (
+                hs.get_device_handler().device_list_updater.user_device_resync
+            )
+        else:
+            self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
+                hs
+            )
 
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
         local_messages = {}
@@ -138,9 +163,7 @@ class DeviceMessageHandler:
             await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
 
             # Immediately attempt a resync in the background
-            run_in_background(
-                self._device_list_updater.user_device_resync, sender_user_id
-            )
+            run_in_background(self._user_device_resync, user_id=sender_user_id)
 
     async def send_device_message(
         self,
@@ -195,7 +218,8 @@ class DeviceMessageHandler:
             )
 
         log_kv({"remote_messages": remote_messages})
+        if self.federation_sender:
             for destination in remote_messages.keys():
                 # Enqueue a new federation transaction to send the new
                 # device messages to each remote destination.
-                self.federation.send_device_messages(destination)
+                self.federation_sender.send_device_messages(destination)
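The registration logic above (and its twin in the receipts handler further down) follows one sharding pattern: only the instance named in the relevant `writers` config handles the EDU locally, and every other instance registers a route so incoming EDUs are forwarded to the owner. A generic sketch of that decision, using only the calls visible in the hunk:

    def register_sharded_edu(hs, edu_type, writers, handler):
        registry = hs.get_federation_registry()
        if hs.get_instance_name() in writers:
            # this instance owns the EDU type: handle it here
            registry.register_edu_handler(edu_type, handler)
        else:
            # otherwise route incoming EDUs to the owning instance(s)
            registry.register_instances_for_edu(edu_type, writers)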
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
@@ -476,8 +476,6 @@ class IdentityHandler(BaseHandler):
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
-        assert self.hs.config.public_baseurl
-
         # we need to tell the client to send the token back to us, since it doesn't
         # otherwise know where to send it, so add submit_url response parameter
         # (see also MSC2078)
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import inspect
 import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
 from urllib.parse import urlencode
 
 import attr
@@ -35,7 +35,7 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
-from synapse.handlers._base import BaseHandler
+from synapse.config.oidc_config import OidcProviderConfig
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
@@ -71,6 +71,144 @@ JWK = Dict[str, str]
 JWKS = TypedDict("JWKS", {"keys": List[JWK]})
 
 
+class OidcHandler:
+    """Handles requests related to the OpenID Connect login flow.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._sso_handler = hs.get_sso_handler()
+
+        provider_confs = hs.config.oidc.oidc_providers
+        # we should not have been instantiated if there is no configured provider.
+        assert provider_confs
+
+        self._token_generator = OidcSessionTokenGenerator(hs)
+        self._providers = {
+            p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+        }  # type: Dict[str, OidcProvider]
+
+    async def load_metadata(self) -> None:
+        """Validate the config and load the metadata from the remote endpoint.
+
+        Called at startup to ensure we have everything we need.
+        """
+        for idp_id, p in self._providers.items():
+            try:
+                await p.load_metadata()
+                await p.load_jwks()
+            except Exception as e:
+                raise Exception(
+                    "Error while initialising OIDC provider %r" % (idp_id,)
+                ) from e
+
+    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/oidc/callback
+
+        Since we might want to display OIDC-related errors in a user-friendly
+        way, we don't raise SynapseError from here. Instead, we call
+        ``self._sso_handler.render_error`` which displays an HTML page for the error.
+
+        Most of the OpenID Connect logic happens here:
+
+        - first, we check if there was any error returned by the provider and
+          display it
+        - then we fetch the session cookie, decode and verify it
+        - the ``state`` query parameter should match with the one stored in the
+          session cookie
+
+        Once we know the session is legit, we then delegate to the OIDC Provider
+        implementation, which will exchange the code with the provider and complete the
+        login/authentication.
+
+        Args:
+            request: the incoming request from the browser.
+        """
+
+        # The provider might redirect with an error.
+        # In that case, just display it as-is.
+        if b"error" in request.args:
+            # error response from the auth server. see:
+            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
+            error = request.args[b"error"][0].decode()
+            description = request.args.get(b"error_description", [b""])[0].decode()
+
+            # Most of the errors returned by the provider could be due by
+            # either the provider misbehaving or Synapse being misconfigured.
+            # The only exception of that is "access_denied", where the user
+            # probably cancelled the login flow. In other cases, log those errors.
+            if error != "access_denied":
+                logger.error("Error from the OIDC provider: %s %s", error, description)
+
+            self._sso_handler.render_error(request, error, description)
+            return
+
+        # otherwise, it is presumably a successful response. see:
+        #  https://tools.ietf.org/html/rfc6749#section-4.1.2
+
+        # Fetch the session cookie
+        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
+        if session is None:
+            logger.info("No session cookie found")
+            self._sso_handler.render_error(
+                request, "missing_session", "No session cookie found"
+            )
+            return
+
+        # Remove the cookie. There is a good chance that if the callback failed
+        # once, it will fail next time and the code will already be exchanged.
+        # Removing it early avoids spamming the provider with token requests.
+        request.addCookie(
+            SESSION_COOKIE_NAME,
+            b"",
+            path="/_synapse/oidc",
+            expires="Thu, Jan 01 1970 00:00:00 UTC",
+            httpOnly=True,
+            sameSite="lax",
+        )
+
+        # Check for the state query parameter
+        if b"state" not in request.args:
+            logger.info("State parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "State parameter is missing"
+            )
+            return
+
+        state = request.args[b"state"][0].decode()
+
+        # Deserialize the session token and verify it.
+        try:
+            session_data = self._token_generator.verify_oidc_session_token(
+                session, state
+            )
+        except (MacaroonDeserializationException, ValueError) as e:
+            logger.exception("Invalid session")
+            self._sso_handler.render_error(request, "invalid_session", str(e))
+            return
+        except MacaroonInvalidSignatureException as e:
+            logger.exception("Could not verify session")
+            self._sso_handler.render_error(request, "mismatching_session", str(e))
+            return
+
+        oidc_provider = self._providers.get(session_data.idp_id)
+        if not oidc_provider:
+            logger.error("OIDC session uses unknown IdP %r", oidc_provider)
+            self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+            return
+
+        if b"code" not in request.args:
+            logger.info("Code parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "Code parameter is missing"
+            )
+            return
+
+        code = request.args[b"code"][0].decode()
+
+        await oidc_provider.handle_oidc_callback(request, session_data, code)
+
+
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint
     """
@@ -85,44 +223,61 @@ class OidcError(Exception):
         return self.error
 
 
-class OidcHandler(BaseHandler):
-    """Handles requests related to the OpenID Connect login flow.
+class OidcProvider:
+    """Wraps the config for a single OIDC IdentityProvider
+
+    Provides methods for handling redirect requests and callbacks via that particular
+    IdP.
     """
 
-    def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+    def __init__(
+        self,
+        hs: "HomeServer",
+        token_generator: "OidcSessionTokenGenerator",
+        provider: OidcProviderConfig,
+    ):
+        self._store = hs.get_datastore()
+
+        self._token_generator = token_generator
+
         self._callback_url = hs.config.oidc_callback_url  # type: str
-        self._scopes = hs.config.oidc_scopes  # type: List[str]
-        self._user_profile_method = hs.config.oidc_user_profile_method  # type: str
+
+        self._scopes = provider.scopes
+        self._user_profile_method = provider.user_profile_method
         self._client_auth = ClientAuth(
-            hs.config.oidc_client_id,
-            hs.config.oidc_client_secret,
-            hs.config.oidc_client_auth_method,
+            provider.client_id, provider.client_secret, provider.client_auth_method,
         )  # type: ClientAuth
-        self._client_auth_method = hs.config.oidc_client_auth_method  # type: str
+        self._client_auth_method = provider.client_auth_method
         self._provider_metadata = OpenIDProviderMetadata(
-            issuer=hs.config.oidc_issuer,
-            authorization_endpoint=hs.config.oidc_authorization_endpoint,
-            token_endpoint=hs.config.oidc_token_endpoint,
-            userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
-            jwks_uri=hs.config.oidc_jwks_uri,
+            issuer=provider.issuer,
+            authorization_endpoint=provider.authorization_endpoint,
+            token_endpoint=provider.token_endpoint,
+            userinfo_endpoint=provider.userinfo_endpoint,
+            jwks_uri=provider.jwks_uri,
         )  # type: OpenIDProviderMetadata
-        self._provider_needs_discovery = hs.config.oidc_discover  # type: bool
-        self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
-            hs.config.oidc_user_mapping_provider_config
-        )  # type: OidcMappingProvider
-        self._skip_verification = hs.config.oidc_skip_verification  # type: bool
-        self._allow_existing_users = hs.config.oidc_allow_existing_users  # type: bool
+        self._provider_needs_discovery = provider.discover
+        self._user_mapping_provider = provider.user_mapping_provider_class(
+            provider.user_mapping_provider_config
+        )
+        self._skip_verification = provider.skip_verification
+        self._allow_existing_users = provider.allow_existing_users
 
         self._http_client = hs.get_proxied_http_client()
         self._server_name = hs.config.server_name  # type: str
-        self._macaroon_secret_key = hs.config.macaroon_secret_key
 
         # identifier for the external_ids table
-        self._auth_provider_id = "oidc"
+        self.idp_id = provider.idp_id
+
+        # user-facing name of this auth provider
+        self.idp_name = provider.idp_name
+
+        # MXC URI for icon for this auth provider
+        self.idp_icon = provider.idp_icon
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _validate_metadata(self):
         """Verifies the provider metadata.
 
@@ -475,7 +630,7 @@ class OidcHandler(BaseHandler):
     async def handle_redirect_request(
         self,
         request: SynapseRequest,
-        client_redirect_url: bytes,
+        client_redirect_url: Optional[bytes],
         ui_auth_session_id: Optional[str] = None,
     ) -> str:
         """Handle an incoming request to /login/sso/redirect
@@ -499,7 +654,7 @@ class OidcHandler(BaseHandler):
             request: the incoming request from the browser.
                 We'll respond to it with a redirect and a cookie.
             client_redirect_url: the URL that we should redirect the client to
-                when everything is done
+                when everything is done (or None for UI Auth)
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
@@ -511,11 +666,17 @@ class OidcHandler(BaseHandler):
         state = generate_token()
         nonce = generate_token()
 
-        cookie = self._generate_oidc_session_token(
+        if not client_redirect_url:
+            client_redirect_url = b""
+
+        cookie = self._token_generator.generate_oidc_session_token(
             state=state,
+            session_data=OidcSessionData(
+                idp_id=self.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url.decode(),
                 ui_auth_session_id=ui_auth_session_id,
+            ),
         )
         request.addCookie(
             SESSION_COOKIE_NAME,
@@ -538,22 +699,16 @@ class OidcHandler(BaseHandler):
             nonce=nonce,
         )
 
-    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+    async def handle_oidc_callback(
+        self, request: SynapseRequest, session_data: "OidcSessionData", code: str
+    ) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
 
-        Since we might want to display OIDC-related errors in a user-friendly
-        way, we don't raise SynapseError from here. Instead, we call
-        ``self._sso_handler.render_error`` which displays an HTML page for the error.
+        By this time we have already validated the session on the synapse side, and
+        now need to do the provider-specific operations. This includes:
 
-        Most of the OpenID Connect logic happens here:
-
-        - first, we check if there was any error returned by the provider and
-          display it
-        - then we fetch the session cookie, decode and verify it
-        - the ``state`` query parameter should match with the one stored in the
-          session cookie
-        - once we known this session is legit, exchange the code with the
-          provider using the ``token_endpoint`` (see ``_exchange_code``)
+        - exchange the code with the provider using the ``token_endpoint`` (see
+          ``_exchange_code``)
         - once we have the token, use it to either extract the UserInfo from
           the ``id_token`` (``_parse_id_token``), or use the ``access_token``
           to fetch UserInfo from the ``userinfo_endpoint``
@@ -563,88 +718,12 @@ class OidcHandler(BaseHandler):
 
         Args:
             request: the incoming request from the browser.
+            session_data: the session data, extracted from our cookie
+            code: The authorization code we got from the callback.
         """
-
-        # The provider might redirect with an error.
-        # In that case, just display it as-is.
-        if b"error" in request.args:
-            # error response from the auth server. see:
-            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
-            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
-            error = request.args[b"error"][0].decode()
-            description = request.args.get(b"error_description", [b""])[0].decode()
-
-            # Most of the errors returned by the provider could be due by
-            # either the provider misbehaving or Synapse being misconfigured.
-            # The only exception of that is "access_denied", where the user
-            # probably cancelled the login flow. In other cases, log those errors.
-            if error != "access_denied":
-                logger.error("Error from the OIDC provider: %s %s", error, description)
-
-            self._sso_handler.render_error(request, error, description)
-            return
-
-        # otherwise, it is presumably a successful response. see:
-        #  https://tools.ietf.org/html/rfc6749#section-4.1.2
-
-        # Fetch the session cookie
-        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
-        if session is None:
-            logger.info("No session cookie found")
-            self._sso_handler.render_error(
-                request, "missing_session", "No session cookie found"
-            )
-            return
-
-        # Remove the cookie. There is a good chance that if the callback failed
-        # once, it will fail next time and the code will already be exchanged.
-        # Removing it early avoids spamming the provider with token requests.
-        request.addCookie(
-            SESSION_COOKIE_NAME,
-            b"",
-            path="/_synapse/oidc",
-            expires="Thu, Jan 01 1970 00:00:00 UTC",
-            httpOnly=True,
-            sameSite="lax",
-        )
-
-        # Check for the state query parameter
-        if b"state" not in request.args:
-            logger.info("State parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "State parameter is missing"
-            )
-            return
-
-        state = request.args[b"state"][0].decode()
-
-        # Deserialize the session token and verify it.
-        try:
-            (
-                nonce,
-                client_redirect_url,
-                ui_auth_session_id,
-            ) = self._verify_oidc_session_token(session, state)
-        except MacaroonDeserializationException as e:
-            logger.exception("Invalid session")
-            self._sso_handler.render_error(request, "invalid_session", str(e))
-            return
-        except MacaroonInvalidSignatureException as e:
-            logger.exception("Could not verify session")
-            self._sso_handler.render_error(request, "mismatching_session", str(e))
-            return
-
         # Exchange the code with the provider
-        if b"code" not in request.args:
-            logger.info("Code parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "Code parameter is missing"
-            )
-            return
-
-        logger.debug("Exchanging code")
-        code = request.args[b"code"][0].decode()
         try:
+            logger.debug("Exchanging code")
             token = await self._exchange_code(code)
         except OidcError as e:
             logger.exception("Could not exchange code")
@@ -666,14 +745,14 @@ class OidcHandler(BaseHandler):
         else:
             logger.debug("Extracting userinfo from id_token")
             try:
-                userinfo = await self._parse_id_token(token, nonce=nonce)
+                userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
             except Exception as e:
                 logger.exception("Invalid id_token")
                 self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
         # first check if we're doing a UIA
-        if ui_auth_session_id:
+        if session_data.ui_auth_session_id:
             try:
                 remote_user_id = self._remote_id_from_userinfo(userinfo)
             except Exception as e:
@@ -682,7 +761,7 @@ class OidcHandler(BaseHandler):
                 return
 
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+                self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
             )
 
         # otherwise, it's a login
@@ -690,133 +769,12 @@ class OidcHandler(BaseHandler):
         # Call the mapper to register/login the user
         try:
             await self._complete_oidc_login(
-                userinfo, token, request, client_redirect_url
+                userinfo, token, request, session_data.client_redirect_url
             )
         except MappingException as e:
             logger.exception("Could not map user")
             self._sso_handler.render_error(request, "mapping_error", str(e))
 
-    def _generate_oidc_session_token(
-        self,
-        state: str,
-        nonce: str,
-        client_redirect_url: str,
-        ui_auth_session_id: Optional[str],
-        duration_in_ms: int = (60 * 60 * 1000),
-    ) -> str:
-        """Generates a signed token storing data about an OIDC session.
-
-        When Synapse initiates an authorization flow, it creates a random state
-        and a random nonce. Those parameters are given to the provider and
-        should be verified when the client comes back from the provider.
-        It is also used to store the client_redirect_url, which is used to
-        complete the SSO login flow.
-
-        Args:
-            state: The ``state`` parameter passed to the OIDC provider.
-            nonce: The ``nonce`` parameter passed to the OIDC provider.
-            client_redirect_url: The URL the client gave when it initiated the
-                flow.
-            ui_auth_session_id: The session ID of the ongoing UI Auth (or
-                None if this is a login).
-            duration_in_ms: An optional duration for the token in milliseconds.
-                Defaults to an hour.
-
-        Returns:
-            A signed macaroon token with the session information.
-        """
-        macaroon = pymacaroons.Macaroon(
-            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
-        )
-        macaroon.add_first_party_caveat("gen = 1")
-        macaroon.add_first_party_caveat("type = session")
-        macaroon.add_first_party_caveat("state = %s" % (state,))
-        macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
-        macaroon.add_first_party_caveat(
-            "client_redirect_url = %s" % (client_redirect_url,)
-        )
-        if ui_auth_session_id:
-            macaroon.add_first_party_caveat(
-                "ui_auth_session_id = %s" % (ui_auth_session_id,)
-            )
-        now = self.clock.time_msec()
-        expiry = now + duration_in_ms
-        macaroon.add_first_party_caveat("time < %d" % (expiry,))
-
-        return macaroon.serialize()
-
-    def _verify_oidc_session_token(
-        self, session: bytes, state: str
-    ) -> Tuple[str, str, Optional[str]]:
-        """Verifies and extract an OIDC session token.
-
-        This verifies that a given session token was issued by this homeserver
-        and extract the nonce and client_redirect_url caveats.
-
-        Args:
-            session: The session token to verify
-            state: The state the OIDC provider gave back
-
-        Returns:
-            The nonce, client_redirect_url, and ui_auth_session_id for this session
-        """
-        macaroon = pymacaroons.Macaroon.deserialize(session)
-
-        v = pymacaroons.Verifier()
-        v.satisfy_exact("gen = 1")
-        v.satisfy_exact("type = session")
-        v.satisfy_exact("state = %s" % (state,))
-        v.satisfy_general(lambda c: c.startswith("nonce = "))
-        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
-        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
-        # to always satisfy this.
-        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
-        v.satisfy_general(self._verify_expiry)
-
-        v.verify(macaroon, self._macaroon_secret_key)
-
-        # Extract the `nonce`, `client_redirect_url`, and maybe the
-        # `ui_auth_session_id` from the token.
-        nonce = self._get_value_from_macaroon(macaroon, "nonce")
-        client_redirect_url = self._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
-        try:
-            ui_auth_session_id = self._get_value_from_macaroon(
-                macaroon, "ui_auth_session_id"
-            )  # type: Optional[str]
-        except ValueError:
-            ui_auth_session_id = None
-
-        return nonce, client_redirect_url, ui_auth_session_id
-
-    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
-        """Extracts a caveat value from a macaroon token.
-
-        Args:
-            macaroon: the token
-            key: the key of the caveat to extract
-
-        Returns:
-            The extracted value
-
-        Raises:
-            Exception: if the caveat was not in the macaroon
-        """
-        prefix = key + " = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(prefix):
-                return caveat.caveat_id[len(prefix) :]
-        raise ValueError("No %s caveat in macaroon" % (key,))
-
-    def _verify_expiry(self, caveat: str) -> bool:
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self.clock.time_msec()
-        return now < expiry
-
     async def _complete_oidc_login(
         self,
         userinfo: UserInfo,
@@ -893,8 +851,8 @@ class OidcHandler(BaseHandler):
         # and attempt to match it.
         attributes = await oidc_response_to_user_attributes(failures=0)
 
-        user_id = UserID(attributes.localpart, self.server_name).to_string()
-        users = await self.store.get_users_by_id_case_insensitive(user_id)
+        user_id = UserID(attributes.localpart, self._server_name).to_string()
+        users = await self._store.get_users_by_id_case_insensitive(user_id)
         if users:
             # If an existing matrix ID is returned, then use it.
             if len(users) == 1:
@@ -923,7 +881,7 @@ class OidcHandler(BaseHandler):
             extra_attributes = await get_extra_attributes(userinfo, token)
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             remote_user_id,
             request,
             client_redirect_url,
@@ -946,6 +904,157 @@ class OidcHandler(BaseHandler):
         return str(remote_user_id)
 
 
+class OidcSessionTokenGenerator:
+    """Methods for generating and checking OIDC Session cookies."""
+
+    def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
+        self._server_name = hs.hostname
+        self._macaroon_secret_key = hs.config.key.macaroon_secret_key
+
+    def generate_oidc_session_token(
+        self,
+        state: str,
+        session_data: "OidcSessionData",
+        duration_in_ms: int = (60 * 60 * 1000),
+    ) -> str:
+        """Generates a signed token storing data about an OIDC session.
+
+        When Synapse initiates an authorization flow, it creates a random state
+        and a random nonce. Those parameters are given to the provider and
+        should be verified when the client comes back from the provider.
+        It is also used to store the client_redirect_url, which is used to
+        complete the SSO login flow.
+
+        Args:
+            state: The ``state`` parameter passed to the OIDC provider.
+            session_data: data to include in the session token.
+            duration_in_ms: An optional duration for the token in milliseconds.
+                Defaults to an hour.
+
+        Returns:
+            A signed macaroon token with the session information.
+        """
+        macaroon = pymacaroons.Macaroon(
+            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+        )
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = session")
+        macaroon.add_first_party_caveat("state = %s" % (state,))
+        macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
+        macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
+        macaroon.add_first_party_caveat(
+            "client_redirect_url = %s" % (session_data.client_redirect_url,)
+        )
+        if session_data.ui_auth_session_id:
+            macaroon.add_first_party_caveat(
+                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+            )
+        now = self._clock.time_msec()
+        expiry = now + duration_in_ms
+        macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
+        return macaroon.serialize()
+
+    def verify_oidc_session_token(
+        self, session: bytes, state: str
+    ) -> "OidcSessionData":
+        """Verifies and extract an OIDC session token.
+
+        This verifies that a given session token was issued by this homeserver
+        and extract the nonce and client_redirect_url caveats.
+
+        Args:
+            session: The session token to verify
+            state: The state the OIDC provider gave back
+
+        Returns:
+            The data extracted from the session cookie
+
+        Raises:
+            ValueError if an expected caveat is missing from the macaroon.
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(session)
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = session")
+        v.satisfy_exact("state = %s" % (state,))
+        v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("idp_id = "))
+        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
+        # to always satisfy this.
+        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
+        v.satisfy_general(self._verify_expiry)
+
+        v.verify(macaroon, self._macaroon_secret_key)
+
+        # Extract the session data from the token.
+        nonce = self._get_value_from_macaroon(macaroon, "nonce")
+        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
+        client_redirect_url = self._get_value_from_macaroon(
+            macaroon, "client_redirect_url"
+        )
+        try:
+            ui_auth_session_id = self._get_value_from_macaroon(
+                macaroon, "ui_auth_session_id"
+            )  # type: Optional[str]
+        except ValueError:
+            ui_auth_session_id = None
+
+        return OidcSessionData(
+            nonce=nonce,
+            idp_id=idp_id,
+            client_redirect_url=client_redirect_url,
+            ui_auth_session_id=ui_auth_session_id,
+        )
+
+    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
+        """Extracts a caveat value from a macaroon token.
+
+        Args:
+            macaroon: the token
+            key: the key of the caveat to extract
+
+        Returns:
+            The extracted value
+
+        Raises:
+            ValueError: if the caveat was not in the macaroon
+        """
+        prefix = key + " = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(prefix):
+                return caveat.caveat_id[len(prefix) :]
+        raise ValueError("No %s caveat in macaroon" % (key,))
+
+    def _verify_expiry(self, caveat: str) -> bool:
+        prefix = "time < "
+        if not caveat.startswith(prefix):
+            return False
+        expiry = int(caveat[len(prefix) :])
+        now = self._clock.time_msec()
+        return now < expiry
+
+
+@attr.s(frozen=True, slots=True)
+class OidcSessionData:
+    """The attributes which are stored in a OIDC session cookie"""
+
+    # the Identity Provider being used
+    idp_id = attr.ib(type=str)
+
+    # The `nonce` parameter passed to the OIDC provider.
+    nonce = attr.ib(type=str)
+
+    # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
+    client_redirect_url = attr.ib(type=str)
+
+    # The session ID of the ongoing UI Auth (None if this is a login)
+    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+
+
 UserAttributeDict = TypedDict(
     "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
 )
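The `OidcSessionTokenGenerator` added above signs the per-login `state`/`nonce` (and now the `idp_id`) into a macaroon, then re-verifies every caveat when the browser returns. A hedged round-trip sketch using pymacaroons directly (the secret and caveat values are illustrative):

    import pymacaroons

    m = pymacaroons.Macaroon(location="example.com", identifier="key", key="secret")
    m.add_first_party_caveat("gen = 1")
    m.add_first_party_caveat("state = abc")
    token = m.serialize()

    v = pymacaroons.Verifier()
    v.satisfy_exact("gen = 1")
    v.satisfy_exact("state = abc")
    # raises MacaroonInvalidSignatureException if the token was forged or altered
    v.verify(pymacaroons.Macaroon.deserialize(token), "secret")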
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
@@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-            return result["displayname"]
+            return result.get("displayname")
 
     async def set_displayname(
         self,
@@ -246,7 +246,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-            return result["avatar_url"]
+            return result.get("avatar_url")
 
     async def set_avatar_url(
         self,
@@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler):
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
             )
 
+        avatar_url_to_set = new_avatar_url  # type: Optional[str]
+        if new_avatar_url == "":
+            avatar_url_to_set = None
+
         # Same like set_displayname
         if by_admin:
             requester = create_requester(
                 target_user, authenticated_entity=requester.authenticated_entity
             )
 
-        await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+        await self.store.set_profile_avatar_url(
+            target_user.localpart, avatar_url_to_set
+        )
 
         if self.hs.config.user_directory_search_all_users:
             profile = await self.store.get_profileinfo(target_user.localpart)
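The switch from `result["displayname"]` to `result.get("displayname")` (and likewise for `avatar_url`) matters because a remote homeserver may return a profile without one of those fields; `.get` then yields None instead of raising. A trivial illustration:

    result = {"avatar_url": "mxc://example.com/abc"}  # remote profile with no displayname
    result["displayname"]       # raises KeyError
    result.get("displayname")   # returns None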
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
@@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
         super().__init__(hs)
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
+        self.account_data_handler = hs.get_account_data_handler()
         self.read_marker_linearizer = Linearizer(name="read_marker")
-        self.notifier = hs.get_notifier()
 
     async def received_client_read_marker(
         self, room_id: str, user_id: str, event_id: str
@@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
 
             if should_update:
                 content = {"event_id": event_id}
-                max_id = await self.store.add_account_data_to_room(
+                await self.account_data_handler.add_account_data_to_room(
                     user_id, room_id, "m.fully_read", content
                 )
-                self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
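Routing the `m.fully_read` write through the new account data handler drops the local notifier poke. Presumably (this is an assumption about `AccountDataHandler`, whose implementation is not shown in this diff) the handler performs the equivalent of both former steps itself, which also works when account data is written on a different worker. A caller-side sketch, mirroring the call in the hunk above with made-up IDs:

    await hs.get_account_data_handler().add_account_data_to_room(
        "@alice:example.com", "!room:example.com", "m.fully_read", {"event_id": "$ev"}
    )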
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
@@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler):
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
         self.hs = hs
-        self.federation = hs.get_federation_sender()
+
+        # We only need to poke the federation sender explicitly if its on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the receipts stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the receipt EDUs we do so, otherwise we route them
+        # to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.receipts:
             hs.get_federation_registry().register_edu_handler(
                 "m.receipt", self._received_remote_receipt
             )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.receipt", hs.config.worker.writers.receipts,
+            )
 
         self.clock = self.hs.get_clock()
         self.state = hs.get_state_handler()
 
@@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler):
         if not is_new:
             return
 
-        await self.federation.send_read_receipt(receipt)
+        if self.federation_sender:
+            await self.federation_sender.send_read_receipt(receipt)
 
 
 class ReceiptEventSource:
||||||
|
|
|
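With this change the receipts handler no longer assumes a federation sender exists on every instance: it keeps one only when `hs.should_send_federation()` says so, and otherwise leaves the attribute as `None` so the send is skipped. A stripped-down sketch of the guard-attribute pattern (the classes here are stand-ins, not Synapse's):

```python
import asyncio
from typing import Optional


class StubFederationSender:
    async def send_read_receipt(self, receipt: dict) -> None:
        print("federating receipt:", receipt)


class ReceiptsHandlerSketch:
    def __init__(self, should_send_federation: bool) -> None:
        # None unless this instance is responsible for federation sending;
        # callers then branch on truthiness rather than re-reading config.
        self.federation_sender = None  # type: Optional[StubFederationSender]
        if should_send_federation:
            self.federation_sender = StubFederationSender()

    async def _handle_new_receipts(self, receipt: dict) -> None:
        if self.federation_sender:
            await self.federation_sender.send_read_receipt(receipt)


asyncio.run(ReceiptsHandlerSketch(True)._handle_new_receipts({"room": "!r:x"}))
asyncio.run(ReceiptsHandlerSketch(False)._handle_new_receipts({"room": "!r:x"}))
```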
@@ -38,7 +38,6 @@ from synapse.api.filtering import Filter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
-from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.storage.state import StateFilter
 from synapse.types import (
     JsonDict,
@@ -55,6 +54,7 @@ from synapse.types import (
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import parse_and_validate_server_name
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -365,7 +365,7 @@ class RoomCreationHandler(BaseHandler):
         creation_content = {
             "room_version": new_room_version.identifier,
             "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
-        }
+        }  # type: JsonDict
 
         # Check if old room was non-federatable
 
@@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.registration_handler = hs.get_registration_handler()
         self.profile_handler = hs.get_profile_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
+        self.account_data_handler = hs.get_account_data_handler()
 
         self.member_linearizer = Linearizer(name="member")
 
@@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     direct_rooms[key].append(new_room_id)
 
                     # Save back to user's m.direct account data
-                    await self.store.add_account_data_for_user(
+                    await self.account_data_handler.add_account_data_for_user(
                         user_id, AccountDataTypes.DIRECT, direct_rooms
                     )
                     break
@@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Copy each room tag to the new room
         for tag, tag_content in room_tags.items():
-            await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+            await self.account_data_handler.add_tag_to_room(
+                user_id, new_room_id, tag, tag_content
+            )
 
     async def update_membership(
         self,
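`RoomMemberHandler` now routes `m.direct` updates and room tags through `hs.get_account_data_handler()` rather than writing to the store directly; on a worker, that handler can forward the write to whichever instance owns the account-data streams. A toy sketch of that dispatch idea (the class, method names and HTTP call here are illustrative, not Synapse's actual client plumbing):

```python
class AccountDataDispatchSketch:
    """Write locally when we own the stream; otherwise proxy over HTTP."""

    def __init__(self, is_writer: bool, store, http_client) -> None:
        self._is_writer = is_writer
        self._store = store
        self._http = http_client

    async def add_account_data_for_user(
        self, user_id: str, account_data_type: str, content: dict
    ) -> int:
        if self._is_writer:
            # Direct path: persist and return the new stream position.
            return await self._store.add_account_data_for_user(
                user_id, account_data_type, content
            )
        # Worker path: ask the writer instance to do it for us.
        resp = await self._http.post_json(
            "/_synapse/replication/add_user_account_data/%s/%s"
            % (user_id, account_data_type),
            {"content": content},
        )
        return resp["max_stream_id"]
```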
@@ -73,27 +73,45 @@ class SamlHandler(BaseHandler):
         )
 
         # identifier for the external_ids table
-        self._auth_provider_id = "saml"
+        self.idp_id = "saml"
+
+        # user-facing name of this auth provider
+        self.idp_name = "SAML"
+
+        # we do not currently support icons for SAML auth, but this is required by
+        # the SsoIdentityProvider protocol type.
+        self.idp_icon = None
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
         self._sso_handler = hs.get_sso_handler()
+        self._sso_handler.register_identity_provider(self)
 
-    def handle_redirect_request(
-        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
-    ) -> bytes:
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
+            request: the incoming HTTP request
            client_redirect_url: the URL that we should redirect the
-                client to when everything is done
+                client to after login (or None for UI Auth).
            ui_auth_session_id: The session ID of the ongoing UI Auth (or
                None if this is a login).
 
        Returns:
            URL to redirect to
        """
+        if not client_redirect_url:
+            # Some SAML identity providers (e.g. Google) require a
+            # RelayState parameter on requests, so pass in a dummy redirect URL
+            # (which will never get used).
+            client_redirect_url = b"unused"
+
        reqid, info = self._saml_client.prepare_for_authenticate(
            entityid=self._saml_idp_entityid, relay_state=client_redirect_url
        )
@@ -210,7 +228,7 @@ class SamlHandler(BaseHandler):
             return
 
         return await self._sso_handler.complete_sso_ui_auth_request(
-            self._auth_provider_id,
+            self.idp_id,
             remote_user_id,
             current_session.ui_auth_session_id,
             request,
@@ -306,7 +324,7 @@ class SamlHandler(BaseHandler):
             return None
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             remote_user_id,
             request,
             client_redirect_url,
@@ -12,15 +12,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
+from urllib.parse import urlencode
 
 import attr
-from typing_extensions import NoReturn
+from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
 
-from synapse.api.errors import RedirectException, SynapseError
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
@@ -40,6 +45,63 @@ class MappingException(Exception):
     """
 
 
+class SsoIdentityProvider(Protocol):
+    """Abstract base class to be implemented by SSO Identity Providers
+
+    An Identity Provider, or IdP, is an external HTTP service which authenticates a user
+    to say whether they should be allowed to log in, or perform a given action.
+
+    Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
+    and CAS.
+
+    The main entry point is `handle_redirect_request`, which should return a URI to
+    redirect the user's browser to the IdP's authentication page.
+
+    Each IdP should be registered with the SsoHandler via
+    `hs.get_sso_handler().register_identity_provider()`, so that requests to
+    `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
+    """
+
+    @property
+    @abc.abstractmethod
+    def idp_id(self) -> str:
+        """A unique identifier for this SSO provider
+
+        Eg, "saml", "cas", "github"
+        """
+
+    @property
+    @abc.abstractmethod
+    def idp_name(self) -> str:
+        """User-facing name for this provider"""
+
+    @property
+    def idp_icon(self) -> Optional[str]:
+        """Optional MXC URI for user-facing icon"""
+        return None
+
+    @abc.abstractmethod
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Handle an incoming request to /login/sso/redirect
+
+        Args:
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
+
+        Returns:
+            URL to redirect to
+        """
+        raise NotImplementedError()
+
+
 @attr.s
 class UserAttributes:
     # the localpart of the mxid that the mapper has assigned to the user.
@@ -91,8 +153,13 @@ class SsoHandler:
         self._store = hs.get_datastore()
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
-        self._error_template = hs.config.sso_error_template
         self._auth_handler = hs.get_auth_handler()
+        self._error_template = hs.config.sso_error_template
+        self._bad_user_template = hs.config.sso_auth_bad_user_template
+
+        # The following template is shown after a successful user interactive
+        # authentication session. It tells the user they can close the window.
+        self._sso_auth_success_template = hs.config.sso_auth_success_template
 
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
@@ -100,6 +167,49 @@ class SsoHandler:
         # a map from session id to session data
         self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
 
+        # map from idp_id to SsoIdentityProvider
+        self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
+
+    def register_identity_provider(self, p: SsoIdentityProvider):
+        p_id = p.idp_id
+        assert p_id not in self._identity_providers
+        self._identity_providers[p_id] = p
+
+    def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
+        """Get the configured identity providers"""
+        return self._identity_providers
+
+    async def get_identity_providers_for_user(
+        self, user_id: str
+    ) -> Mapping[str, SsoIdentityProvider]:
+        """Get the SsoIdentityProviders which a user has used
+
+        Given a user id, get the identity providers that that user has used to log in
+        with in the past (and thus could use to re-identify themselves for UI Auth).
+
+        Args:
+            user_id: MXID of user to look up
+
+        Raises:
+            a map of idp_id to SsoIdentityProvider
+        """
+        external_ids = await self._store.get_external_ids_by_user(user_id)
+
+        valid_idps = {}
+        for idp_id, _ in external_ids:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                logger.warning(
+                    "User %r has an SSO mapping for IdP %r, but this is no longer "
+                    "configured.",
+                    user_id,
+                    idp_id,
+                )
+            else:
+                valid_idps[idp_id] = idp
+
+        return valid_idps
+
     def render_error(
         self,
         request: Request,
@@ -124,6 +234,34 @@ class SsoHandler:
         )
         respond_with_html(request, code, html)
 
+    async def handle_redirect_request(
+        self, request: SynapseRequest, client_redirect_url: bytes,
+    ) -> str:
+        """Handle a request to /login/sso/redirect
+
+        Args:
+            request: incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login.
+
+        Returns:
+            the URI to redirect to
+        """
+        if not self._identity_providers:
+            raise SynapseError(
+                400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
+            )
+
+        # if we only have one auth provider, redirect to it directly
+        if len(self._identity_providers) == 1:
+            ap = next(iter(self._identity_providers.values()))
+            return await ap.handle_redirect_request(request, client_redirect_url)
+
+        # otherwise, redirect to the IDP picker
+        return "/_synapse/client/pick_idp?" + urlencode(
+            (("redirectUrl", client_redirect_url),)
+        )
+
     async def get_sso_user_by_remote_user_id(
         self, auth_provider_id: str, remote_user_id: str
     ) -> Optional[str]:
@@ -268,7 +406,7 @@ class SsoHandler:
             attributes,
             auth_provider_id,
             remote_user_id,
-            request.get_user_agent(""),
+            get_request_user_agent(request),
             request.getClientIP(),
         )
 
@@ -451,19 +589,45 @@ class SsoHandler:
             auth_provider_id, remote_user_id,
         )
 
+        user_id_to_verify = await self._auth_handler.get_session_data(
+            ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
         if not user_id:
             logger.warning(
                 "Remote user %s/%s has not previously logged in here: UIA will fail",
                 auth_provider_id,
                 remote_user_id,
             )
-            # Let the UIA flow handle this the same as if they presented creds for a
-            # different user.
-            user_id = ""
-
-        await self._auth_handler.complete_sso_ui_auth(
-            user_id, ui_auth_session_id, request
-        )
+        elif user_id != user_id_to_verify:
+            logger.warning(
+                "Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
+                auth_provider_id,
+                remote_user_id,
+                user_id,
+            )
+        else:
+            # success!
+            # Mark the stage of the authentication as successful.
+            await self._store.mark_ui_auth_stage_complete(
+                ui_auth_session_id, LoginType.SSO, user_id
+            )
+
+            # Render the HTML confirmation page and return.
+            html = self._sso_auth_success_template
+            respond_with_html(request, 200, html)
+            return
+
+        # the user_id didn't match: mark the stage of the authentication as unsuccessful
+        await self._store.mark_ui_auth_stage_complete(
+            ui_auth_session_id, LoginType.SSO, ""
+        )
+
+        # render an error page.
+        html = self._bad_user_template.render(
+            server_name=self._server_name, user_id_to_verify=user_id_to_verify,
+        )
+        respond_with_html(request, 200, html)
 
     async def check_username_availability(
         self, localpart: str, session_id: str,
@@ -534,7 +698,7 @@ class SsoHandler:
             attributes,
             session.auth_provider_id,
             session.remote_user_id,
-            request.get_user_agent(""),
+            get_request_user_agent(request),
             request.getClientIP(),
         )
 
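The `SsoIdentityProvider` protocol introduced above is what the CAS, SAML and OIDC handlers now implement and register. A minimal, entirely hypothetical provider satisfying the protocol could look like this (the `FakeIdp` class and its upstream URL are invented for illustration):

```python
from typing import Optional
from urllib.parse import quote


class FakeIdp:
    """Illustrative SsoIdentityProvider implementation."""

    idp_id = "fake"        # unique key, used in the external_ids table
    idp_name = "Fake IdP"  # user-facing name, e.g. on the picker page
    idp_icon = None        # no MXC icon for this provider

    async def handle_redirect_request(
        self,
        request,  # the incoming SynapseRequest
        client_redirect_url: Optional[bytes],
        ui_auth_session_id: Optional[str] = None,
    ) -> str:
        # Send the browser to the (imaginary) upstream login page, carrying
        # the client redirect URL as opaque state for the round trip back.
        state = quote(client_redirect_url or b"")
        return "https://idp.example.com/login?state=" + state
```

An instance would then be handed to `hs.get_sso_handler().register_identity_provider(FakeIdp())` at startup, exactly as the SAML handler does above.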
@@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
 """
 
 from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS  # noqa: F401
+
+
+class UIAuthSessionDataConstants:
+    """Constants for use with AuthHandler.set_session_data"""
+
+    # used during registration and password reset to store a hashed copy of the
+    # password, so that the client does not need to submit it each time.
+    PASSWORD_HASH = "password_hash"
+
+    # used during registration to store the mxid of the registered user
+    REGISTERED_USER_ID = "registered_user_id"
+
+    # used by validate_user_via_ui_auth to store the mxid of the user we are validating
+    # for.
+    REQUEST_USER_ID = "request_user_id"
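These constants replace ad-hoc string keys when stashing values in a UI-auth session; `complete_sso_ui_auth_request` above reads `REQUEST_USER_ID` back out via `get_session_data`. A sketch of the round trip, assuming an `auth_handler` exposing the `set_session_data`/`get_session_data` pair the docstring refers to:

```python
async def remember_requesting_user(auth_handler, session_id: str, user_id: str) -> None:
    # Record which user this UI-auth session is meant to validate...
    await auth_handler.set_session_data(
        session_id, UIAuthSessionDataConstants.REQUEST_USER_ID, user_id
    )


async def sso_user_matches(auth_handler, session_id: str, sso_user_id: str) -> bool:
    # ...and later check that SSO mapped back onto the same user.
    expected = await auth_handler.get_session_data(
        session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
    )
    return sso_user_id == expected
```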
@@ -17,6 +17,7 @@ import re
 
 from twisted.internet import task
 from twisted.web.client import FileBodyProducer
+from twisted.web.iweb import IRequest
 
 from synapse.api.errors import SynapseError
 
@@ -50,3 +51,17 @@ class QuieterFileBodyProducer(FileBodyProducer):
             FileBodyProducer.stopProducing(self)
         except task.TaskStopped:
             pass
+
+
+def get_request_user_agent(request: IRequest, default: str = "") -> str:
+    """Return the last User-Agent header, or the given default.
+    """
+    # There could be raw utf-8 bytes in the User-Agent header.
+
+    # N.B. if you don't do this, the logger explodes cryptically
+    # with maximum recursion trying to log errors about
+    # the charset problem.
+    # c.f. https://github.com/matrix-org/synapse/issues/3471
+
+    h = request.getHeader(b"User-Agent")
+    return h.decode("ascii", "replace") if h else default
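The new helper centralises the decoding dance previously inlined in `SynapseRequest.get_user_agent` (removed further down in this commit). Behaviour with a stub standing in for twisted's `IRequest`, assuming the helper above is in scope:

```python
class StubRequest:
    """Just enough of an IRequest for the example."""

    def __init__(self, ua) -> None:
        self._ua = ua

    def getHeader(self, name):
        return self._ua if name == b"User-Agent" else None


# Raw UTF-8 in the header must not crash logging, so the helper decodes as
# ASCII and substitutes anything unrepresentable.
print(get_request_user_agent(StubRequest(b"Mozilla/5.0 \xe2\x98\x83"), "-"))
print(get_request_user_agent(StubRequest(None), "-"))  # header absent -> "-"
```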
@@ -32,7 +32,7 @@ from typing import (
 
 import treq
 from canonicaljson import encode_canonical_json
-from netaddr import IPAddress, IPSet
+from netaddr import AddrFormatError, IPAddress, IPSet
 from prometheus_client import Counter
 from zope.interface import implementer, provider
 
@@ -261,16 +261,16 @@ class BlacklistingAgentWrapper(Agent):
 
         try:
             ip_address = IPAddress(h.hostname)
+        except AddrFormatError:
+            # Not an IP
+            pass
+        else:
             if check_against_blacklist(
                 ip_address, self._ip_whitelist, self._ip_blacklist
             ):
                 logger.info("Blocking access to %s due to blacklist" % (ip_address,))
                 e = SynapseError(403, "IP address blocked by IP blacklist entry")
                 return defer.fail(Failure(e))
-        except Exception:
-            # Not an IP
-            pass
 
         return self._agent.request(
             method, uri, headers=headers, bodyProducer=bodyProducer
@@ -725,7 +725,7 @@ class SimpleHttpClient:
                 read_body_with_max_size(response, output_stream, max_size)
             )
         except BodyExceededMaxSize:
-            SynapseError(
+            raise SynapseError(
                 502,
                 "Requested file is too large > %r bytes" % (max_size,),
                 Codes.TOO_LARGE,
@@ -767,14 +767,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         self.max_size = max_size
 
     def dataReceived(self, data: bytes) -> None:
+        # If the deferred was called, bail early.
+        if self.deferred.called:
+            return
+
         self.stream.write(data)
         self.length += len(data)
+        # The first time the maximum size is exceeded, error and cancel the
+        # connection. dataReceived might be called again if data was received
+        # in the meantime.
         if self.max_size is not None and self.length >= self.max_size:
             self.deferred.errback(BodyExceededMaxSize())
-            self.deferred = defer.Deferred()
             self.transport.loseConnection()
 
     def connectionLost(self, reason: Failure) -> None:
+        # If the maximum size was already exceeded, there's nothing to do.
+        if self.deferred.called:
+            return
+
         if reason.check(ResponseDone):
             self.deferred.callback(self.length)
         elif reason.check(PotentialDataLoss):
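Two of these hunks share a theme. The blacklist check switches to `try`/`except AddrFormatError`/`else`, so only the `IPAddress()` parse is guarded and a genuine failure inside the blacklist check can no longer be swallowed as "not an IP"; and `raise SynapseError(...)` actually raises the error that was previously constructed and discarded. The first fix's control flow, in miniature:

```python
from netaddr import AddrFormatError, IPAddress, IPSet


def is_blocked(hostname: str, blacklist: IPSet) -> bool:
    try:
        ip_address = IPAddress(hostname)
    except AddrFormatError:
        # A hostname rather than an IP literal: nothing to check here.
        return False
    else:
        # Runs only if parsing succeeded; an exception raised in this branch
        # now propagates instead of being mistaken for a parse failure.
        return ip_address in blacklist


assert is_blocked("10.0.0.1", IPSet(["10.0.0.0/8"]))
assert not is_blocked("example.com", IPSet(["10.0.0.0/8"]))
```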
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-import re
-
-logger = logging.getLogger(__name__)
-
-
-def parse_server_name(server_name):
-    """Split a server name into host/port parts.
-
-    Args:
-        server_name (str): server name to parse
-
-    Returns:
-        Tuple[str, int|None]: host/port parts.
-
-    Raises:
-        ValueError if the server name could not be parsed.
-    """
-    try:
-        if server_name[-1] == "]":
-            # ipv6 literal, hopefully
-            return server_name, None
-
-        domain_port = server_name.rsplit(":", 1)
-        domain = domain_port[0]
-        port = int(domain_port[1]) if domain_port[1:] else None
-        return domain, port
-    except Exception:
-        raise ValueError("Invalid server name '%s'" % server_name)
-
-
-VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
-
-
-def parse_and_validate_server_name(server_name):
-    """Split a server name into host/port parts and do some basic validation.
-
-    Args:
-        server_name (str): server name to parse
-
-    Returns:
-        Tuple[str, int|None]: host/port parts.
-
-    Raises:
-        ValueError if the server name could not be parsed.
-    """
-    host, port = parse_server_name(server_name)
-
-    # these tests don't need to be bulletproof as we'll find out soon enough
-    # if somebody is giving us invalid data. What we *do* need is to be sure
-    # that nobody is sneaking IP literals in that look like hostnames, etc.
-
-    # look for ipv6 literals
-    if host[0] == "[":
-        if host[-1] != "]":
-            raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
-        return host, port
-
-    # otherwise it should only be alphanumerics.
-    if not VALID_HOST_REGEX.match(host):
-        raise ValueError(
-            "Server name '%s' contains invalid characters" % (server_name,)
-        )
-
-    return host, port
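This file (`synapse/http/endpoint.py`) is deleted wholesale; as the import changes in the room handler earlier in this commit show, `parse_and_validate_server_name` now lives in `synapse.util.stringutils`. Assuming it kept the semantics of the removed code above, it behaves like so:

```python
from synapse.util.stringutils import parse_and_validate_server_name

print(parse_and_validate_server_name("matrix.org"))       # ('matrix.org', None)
print(parse_and_validate_server_name("matrix.org:8448"))  # ('matrix.org', 8448)
print(parse_and_validate_server_name("[2001:db8::1]"))    # ipv6 literal, no port

try:
    parse_and_validate_server_name("under_score.example")  # '_' is not allowed
except ValueError as e:
    print(e)
```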
@@ -102,7 +102,6 @@ class MatrixFederationAgent:
                     pool=self._pool,
                     contextFactory=tls_client_options_factory,
                 ),
-                self._reactor,
                 ip_blacklist=ip_blacklist,
             ),
             user_agent=self.user_agent,
@@ -174,6 +174,16 @@ async def _handle_json_response(
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
         body = await make_deferred_yieldable(d)
+    except ValueError as e:
+        # The JSON content was invalid.
+        logger.warning(
+            "{%s} [%s] Failed to parse JSON response - %s %s",
+            request.txn_id,
+            request.destination,
+            request.method,
+            request.uri.decode("ascii"),
+        )
+        raise RequestSendFailed(e, can_retry=False) from e
     except defer.TimeoutError as e:
         logger.warning(
             "{%s} [%s] Timed out reading response - %s %s",
@@ -986,7 +996,7 @@ class MatrixFederationHttpClient:
                 logger.warning(
                     "{%s} [%s] %s", request.txn_id, request.destination, msg,
                 )
-                SynapseError(502, msg, Codes.TOO_LARGE)
+                raise SynapseError(502, msg, Codes.TOO_LARGE)
             except Exception as e:
                 logger.warning(
                     "{%s} [%s] Error reading response: %s",
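Both federation-client hunks fix the same easy-to-miss bug: an exception instantiated but never raised, so the oversized-response path fell through as if nothing were wrong. Reduced to its essence:

```python
def check_size_buggy(length: int, max_size: int) -> None:
    if length > max_size:
        ValueError("response too large")  # constructed, then discarded: no error!


def check_size_fixed(length: int, max_size: int) -> None:
    if length > max_size:
        raise ValueError("response too large")  # actually interrupts the caller


check_size_buggy(10_000, 100)  # returns silently
try:
    check_size_fixed(10_000, 100)
except ValueError:
    print("rejected, as intended")
```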
@@ -20,7 +20,7 @@ from twisted.python.failure import Failure
 from twisted.web.server import Request, Site
 
 from synapse.config.server import ListenerConfig
-from synapse.http import redact_uri
+from synapse.http import get_request_user_agent, redact_uri
 from synapse.http.request_metrics import RequestMetrics, requests_counter
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.types import Requester
@@ -113,15 +113,6 @@ class SynapseRequest(Request):
             method = self.method.decode("ascii")
         return method
 
-    def get_user_agent(self, default: str) -> str:
-        """Return the last User-Agent header, or the given default.
-        """
-        user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
-        if user_agent is None:
-            return default
-
-        return user_agent.decode("ascii", "replace")
-
     def render(self, resrc):
         # this is called once a Resource has been found to serve the request; in our
         # case the Resource in question will normally be a JsonResource.
@@ -292,12 +283,7 @@ class SynapseRequest(Request):
             # and can see that we're doing something wrong.
             authenticated_entity = repr(self.requester)  # type: ignore[unreachable]
 
-        # ...or could be raw utf-8 bytes in the User-Agent header.
-        # N.B. if you don't do this, the logger explodes cryptically
-        # with maximum recursion trying to log errors about
-        # the charset problem.
-        # c.f. https://github.com/matrix-org/synapse/issues/3471
-        user_agent = self.get_user_agent("-")
+        user_agent = get_request_user_agent(self, "-")
 
         code = str(self.code)
         if not self.finished:
@@ -252,7 +252,12 @@ class LoggingContext:
         "scope",
     ]
 
-    def __init__(self, name=None, parent_context=None, request=None) -> None:
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        parent_context: "Optional[LoggingContext]" = None,
+        request: Optional[str] = None,
+    ) -> None:
         self.previous_context = current_context()
         self.name = name
 
@@ -536,20 +541,20 @@ class LoggingContextFilter(logging.Filter):
     def __init__(self, request: str = ""):
         self._default_request = request
 
-    def filter(self, record) -> Literal[True]:
+    def filter(self, record: logging.LogRecord) -> Literal[True]:
         """Add each fields from the logging contexts to the record.
         Returns:
             True to include the record in the log output.
         """
         context = current_context()
-        record.request = self._default_request
+        record.request = self._default_request  # type: ignore
 
         # context should never be None, but if it somehow ends up being, then
         # we end up in a death spiral of infinite loops, so let's check, for
         # robustness' sake.
         if context is not None:
             # Logging is interested in the request.
-            record.request = context.request
+            record.request = context.request  # type: ignore
 
         return True
 
@@ -616,9 +621,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel
     return current
 
 
-def nested_logging_context(
-    suffix: str, parent_context: Optional[LoggingContext] = None
-) -> LoggingContext:
+def nested_logging_context(suffix: str) -> LoggingContext:
     """Creates a new logging context as a child of another.
 
     The nested logging context will have a 'request' made up of the parent context's
@@ -632,20 +635,23 @@ def nested_logging_context(
         # ... do stuff
 
     Args:
-        suffix (str): suffix to add to the parent context's 'request'.
-        parent_context (LoggingContext|None): parent context. Will use the current context
-            if None.
+        suffix: suffix to add to the parent context's 'request'.
 
     Returns:
         LoggingContext: new logging context.
     """
-    if parent_context is not None:
-        context = parent_context  # type: LoggingContextOrSentinel
-    else:
-        context = current_context()
-    return LoggingContext(
-        parent_context=context, request=str(context.request) + "-" + suffix
-    )
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Starting nested logging context from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+        prefix = ""
+    else:
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
+        prefix = str(parent_context.request)
+    return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
 
 
 def preserve_fn(f):
@@ -822,10 +828,18 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         Deferred: A Deferred which fires a callback with the result of `f`, or an
             errback if `f` throws an exception.
     """
-    logcontext = current_context()
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Calling defer_to_threadpool from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+    else:
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
 
     def g():
-        with LoggingContext(parent_context=logcontext):
+        with LoggingContext(parent_context=parent_context):
             return f(*args, **kwargs)
 
     return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
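`nested_logging_context` drops its `parent_context` parameter and always derives from the calling context, warning (as `defer_to_threadpool` now also does) when invoked from the sentinel context, where metrics would be lost. Usage after the change, following the docstring's own example:

```python
from synapse.logging.context import LoggingContext, nested_logging_context

with LoggingContext(request="request-123"):
    # The child picks up its parent implicitly from the current context;
    # its 'request' becomes "request-123-subtask".
    with nested_logging_context(suffix="subtask"):
        pass  # ... do stuff
```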
@@ -396,7 +396,6 @@ class Notifier:
 
         Will wake up all listeners for the given users and rooms.
         """
-        with PreserveLoggingContext():
         with Measure(self.clock, "on_new_event"):
             user_streams = set()
 
@@ -203,13 +203,17 @@ class BulkPushRuleEvaluator:
 
         condition_cache = {}  # type: Dict[str, bool]
 
+        # If the event is not a state event check if any users ignore the sender.
+        if not event.is_state():
+            ignorers = await self.store.ignored_by(event.sender)
+        else:
+            ignorers = set()
+
         for uid, rules in rules_by_user.items():
             if event.sender == uid:
                 continue
 
-            if not event.is_state():
-                is_ignored = await self.store.is_ignored_by(event.sender, uid)
-                if is_ignored:
-                    continue
+            if uid in ignorers:
+                continue
 
             display_name = None
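The push-rule evaluator stops issuing one `is_ignored_by` query per candidate user inside the loop; it now fetches the set of users ignoring the sender once (`ignored_by`) and tests membership. The shape of that optimisation, with stand-in data in place of the store:

```python
import asyncio


async def ignored_by(sender: str) -> set:
    # Stand-in for store.ignored_by(sender): one query returns the whole set.
    return {"@carol:example.org"}


async def users_to_evaluate(sender: str, rules_by_user: dict) -> list:
    ignorers = await ignored_by(sender)  # hoisted out of the loop
    result = []
    for uid in rules_by_user:
        if uid == sender:
            continue
        if uid in ignorers:  # O(1) set test replaces a DB round trip
            continue
        result.append(uid)
    return result


print(
    asyncio.run(
        users_to_evaluate(
            "@alice:example.org", {"@bob:example.org": [], "@carol:example.org": []}
        )
    )
)  # ['@bob:example.org']
```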
@@ -15,6 +15,7 @@
 
 from synapse.http.server import JsonResource
 from synapse.replication.http import (
+    account_data,
     devices,
     federation,
     login,
@@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
         presence.register_servlets(hs, self)
         membership.register_servlets(hs, self)
         streams.register_servlets(hs, self)
+        account_data.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
@@ -177,7 +177,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         @trace(opname="outgoing_replication_request")
         @outgoing_gauge.track_inprogress()
-        async def send_request(instance_name="master", **kwargs):
+        async def send_request(*, instance_name="master", **kwargs):
            if instance_name == local_instance_name:
                raise Exception("Trying to send HTTP request to self")
            if instance_name == "master":
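The lone `*` added to `send_request` makes `instance_name` keyword-only, so a stray positional argument can no longer be silently absorbed into it. In isolation:

```python
def send_request(*, instance_name: str = "master", **kwargs):
    """Everything after the bare * must be passed by name."""
    return instance_name, kwargs


print(send_request(instance_name="account_data1", user_id="@a:b"))

try:
    send_request("account_data1")  # positional: rejected outright
except TypeError as e:
    print(e)
```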
synapse/replication/http/account_data.py (new file, 187 lines)
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
+    """Add user account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_user_account_data/:user_id/:type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_user_account_data"
+    PATH_ARGS = ("user_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_for_user(
+            user_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
+    """Add room account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_room_account_data"
+    PATH_ARGS = ("user_id", "room_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_to_room(
+            user_id, room_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationAddTagRestServlet(ReplicationEndpoint):
+    """Add tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_tag"
+    PATH_ARGS = ("user_id", "room_id", "tag")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_tag_to_room(
+            user_id, room_id, tag, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
+    """Remove tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
+
+        {}
+
+    """
+
+    NAME = "remove_tag"
+    PATH_ARGS = (
+        "user_id",
+        "room_id",
+        "tag",
+    )
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag):
+
+        return {}
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+def register_servlets(hs, http_server):
+    ReplicationUserAccountDataRestServlet(hs).register(http_server)
+    ReplicationRoomAccountDataRestServlet(hs).register(http_server)
+    ReplicationAddTagRestServlet(hs).register(http_server)
+    ReplicationRemoveTagRestServlet(hs).register(http_server)
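Each servlet in the new file follows the standard `ReplicationEndpoint` recipe: `NAME` and `PATH_ARGS` define the URL, `_serialize_payload` builds the body on the calling worker, and `_handle_request` runs on the account-data writer and returns the new `max_stream_id`. Callers do not assemble URLs by hand; a sketch of the client side, assuming the base class's usual `make_client` factory:

```python
def make_add_tag_caller(hs):
    # Build the HTTP client for the add_tag endpoint; the writer-side
    # handler is ReplicationAddTagRestServlet above.
    add_tag_client = ReplicationAddTagRestServlet.make_client(hs)

    async def add_tag(user_id: str, room_id: str, tag: str, order: float) -> int:
        # Keyword arguments match PATH_ARGS plus the serialized payload.
        response = await add_tag_client(
            user_id=user_id, room_id=room_id, tag=tag, content={"order": order}
        )
        return response["max_stream_id"]

    return add_tag
```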
@@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )  # type: Optional[MultiWriterIdGenerator]
@@ -15,47 +15,9 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.tags import TagsWorkerStore
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = SlavedIdTracker(
-            db_conn,
-            "account_data",
-            "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self):
-        return self._account_data_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        elif stream_name == AccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                if not row.room_id:
-                    self.get_global_account_data_by_type_for_user.invalidate(
-                        (row.data_type, row.user_id)
-                    )
-                self.get_account_data_for_user.invalidate((row.user_id,))
-                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
-                self.get_account_data_for_room_and_type.invalidate(
-                    (row.user_id, row.room_id, row.data_type)
-                )
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
@@ -14,46 +14,8 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-        self._device_inbox_id_gen = SlavedIdTracker(
-            db_conn, "device_inbox", "stream_id"
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-
-        self._last_device_delete_cache = ExpiringCache(
-            cache_name="last_device_delete_cache",
-            clock=self._clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-        )
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ToDeviceStream.NAME:
-            self._device_inbox_id_gen.advance(instance_name, token)
-            for row in rows:
-                if row.entity.startswith("@"):
-                    self._device_inbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-                else:
-                    self._device_federation_outbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
@@ -14,43 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = SlavedIdTracker(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
-
-    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
-        self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
-        self.get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
-        )
-        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
-        self.get_receipts_for_room.invalidate((room_id, receipt_type))
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ReceiptsStream.NAME:
-            self._receipts_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.invalidate_caches_for_receipt(
-                    row.room_id, row.receipt_type, row.user_id
-                )
-                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
@@ -51,11 +51,15 @@ from synapse.replication.tcp.commands import (
 from synapse.replication.tcp.protocol import AbstractConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
+    AccountDataStream,
     BackfillStream,
     CachesStream,
     EventsStream,
     FederationStream,
+    ReceiptsStream,
     Stream,
+    TagAccountDataStream,
+    ToDeviceStream,
     TypingStream,
 )
 
@@ -115,6 +119,14 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, ToDeviceStream):
+                # Only add ToDeviceStream as a source on instances in charge of
+                # sending to device messages.
+                if hs.get_instance_name() in hs.config.worker.writers.to_device:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             if isinstance(stream, TypingStream):
                 # Only add TypingStream as a source on the instance in charge of
                 # typing.
@@ -123,6 +135,22 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+                # Only add AccountDataStream and TagAccountDataStream as a source on the
+                # instance in charge of account_data persistence.
+                if hs.get_instance_name() in hs.config.worker.writers.account_data:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
+            if isinstance(stream, ReceiptsStream):
+                # Only add ReceiptsStream as a source on the instance in charge of
+                # receipts.
+                if hs.get_instance_name() in hs.config.worker.writers.receipts:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker_app is not None:
                 continue
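Editor's note (illustrative, not part of the commit): the blocks above all follow the same pattern — a stream is only added as a replication source on the instance named in the corresponding worker.writers entry. A minimal Python sketch of that gating, with a hypothetical writers mapping:

    # Hypothetical stand-in for hs.config.worker.writers; in Synapse these are
    # lists of instance names, defaulting to "master".
    writers = {
        "to_device": ["master"],
        "account_data": ["master"],
        "receipts": ["worker1"],
    }

    def should_act_as_source(instance_name: str, stream_name: str) -> bool:
        # Mirrors the `hs.get_instance_name() in hs.config.worker.writers.<x>`
        # checks in the hunks above.
        return instance_name in writers.get(stream_name, ["master"])

    assert should_act_as_source("worker1", "receipts")
    assert not should_act_as_source("master", "receipts")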
synapse/res/templates/sso_auth_bad_user.html (new file, 18 lines)
@@ -0,0 +1,18 @@
+<html>
+<head>
+    <title>Authentication Failed</title>
+</head>
+<body>
+    <div>
+        <p>
+            We were unable to validate your <tt>{{server_name | e}}</tt> account via
+            single-sign-on (SSO), because the SSO Identity Provider returned
+            different details than when you logged in.
+        </p>
+        <p>
+            Try the operation again, and ensure that you use the same details on
+            the Identity Provider as when you log into your account.
+        </p>
+    </div>
+</body>
+</html>

synapse/res/templates/sso_login_idp_picker.html (new file, 31 lines)
@@ -0,0 +1,31 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <link rel="stylesheet" href="/_matrix/static/client/login/style.css">
+    <title>{{server_name | e}} Login</title>
+</head>
+<body>
+    <div id="container">
+        <h1 id="title">{{server_name | e}} Login</h1>
+        <div class="login_flow">
+            <p>Choose one of the following identity providers:</p>
+            <form>
+                <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}">
+                <ul class="radiobuttons">
+{% for p in providers %}
+                    <li>
+                        <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
+                        <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
+{% if p.idp_icon %}
+                        <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
+{% endif %}
+                    </li>
+{% endfor %}
+                </ul>
+                <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
+            </form>
+        </div>
+    </div>
+</body>
+</html>
@@ -15,6 +15,9 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
 
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
     assert_requester_is_admin,
     assert_user_is_admin,
 )
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
         admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, room_id: str):
+    async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
 
     PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, user_id: str):
+    async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
         "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, server_name: str, media_id: str):
+    async def on_POST(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
         return 200, {}
 
 
+class ProtectMediaByID(RestServlet):
+    """Protect local media from being quarantined.
+    """
+
+    PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        logging.info("Protecting local media by ID: %s", media_id)
+
+        # Quarantine this media id
+        await self.store.mark_local_media_as_safe(media_id)
+
+        return 200, {}
+
+
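The new ProtectMediaByID endpoint can be exercised roughly as below — an illustrative sketch, assuming the standard /_synapse/admin/v1 prefix that admin_patterns applies; the base URL and token are placeholders:

    import requests

    def protect_media(base_url: str, admin_token: str, media_id: str) -> dict:
        # POST with an empty body; the servlet above returns 200 with {}.
        resp = requests.post(
            f"{base_url}/_synapse/admin/v1/media/protect/{media_id}",
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        resp.raise_for_status()
        return resp.json()

    # e.g. protect_media("https://homeserver.example", "<admin token>", "abcdef123456")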
 class ListMediaInRoom(RestServlet):
     """Lists all of the media in a given room.
     """
 
     PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         is_admin = await self.auth.is_server_admin(requester.user)
         if not is_admin:
@@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
 class PurgeMediaCacheRestServlet(RestServlet):
     PATTERNS = admin_patterns("/purge_media_cache")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.media_repository = hs.get_media_repository()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_DELETE(self, request, server_name: str, media_id: str):
+    async def on_DELETE(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         if self.server_name != server_name:
@@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_POST(self, request, server_name: str):
+    async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
         return 200, {"deleted_media": deleted_media, "total": total}
 
 
-def register_servlets_for_media_repo(hs, http_server):
+def register_servlets_for_media_repo(hs: "HomeServer", http_server):
     """
     Media repo specific APIs.
     """
@@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
     QuarantineMediaInRoom(hs).register(http_server)
     QuarantineMediaByID(hs).register(http_server)
     QuarantineMediaByUser(hs).register(http_server)
+    ProtectMediaByID(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
     DeleteMediaByID(hs).register(http_server)
     DeleteMediaByDateSize(hs).register(http_server)
@@ -244,7 +244,7 @@ class UserRestServletV2(RestServlet):
 
             if deactivate and not user["deactivated"]:
                 await self.deactivate_account_handler.deactivate_account(
-                    target_user.to_string(), False
+                    target_user.to_string(), False, requester, by_admin=True
                 )
             elif not deactivate and user["deactivated"]:
                 if "password" not in body:
@@ -486,12 +486,22 @@ class WhoisRestServlet(RestServlet):
 class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
         self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
+        self.store = hs.get_datastore()
+
+    async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        if not self.is_mine(UserID.from_string(target_user_id)):
+            raise SynapseError(400, "Can only deactivate local users")
+
+        if not await self.store.get_user_by_id(target_user_id):
+            raise NotFoundError("User not found")
 
-    async def on_POST(self, request, target_user_id):
-        await assert_requester_is_admin(self.auth, request)
         body = parse_json_object_from_request(request, allow_empty_body=True)
         erase = body.get("erase", False)
         if not isinstance(erase, bool):
@@ -501,10 +511,8 @@ class DeactivateAccountRestServlet(RestServlet):
                 Codes.BAD_JSON,
             )
 
-        UserID.from_string(target_user_id)
-
         result = await self._deactivate_account_handler.deactivate_account(
-            target_user_id, erase
+            target_user_id, erase, requester, by_admin=True
         )
         if result:
             id_server_unbind_result = "success"
@@ -714,13 +722,6 @@ class UserMembershipRestServlet(RestServlet):
     async def on_GET(self, request, user_id):
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only lookup local users")
-
-        user = await self.store.get_user_by_id(user_id)
-        if user is None:
-            raise NotFoundError("Unknown user")
-
         room_ids = await self.store.get_rooms_for_user(user_id)
         ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
         return 200, ret
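The deactivate servlet above now passes the acting requester and by_admin=True through to the handler; calling the endpoint itself is unchanged. An illustrative sketch (placeholder base URL and token):

    import requests

    def deactivate_user(base_url: str, admin_token: str, user_id: str, erase: bool = False) -> dict:
        # "erase" is validated as a bool by the servlet above; it defaults to False.
        resp = requests.post(
            f"{base_url}/_synapse/admin/v1/deactivate/{user_id}",
            headers={"Authorization": f"Bearer {admin_token}"},
            json={"erase": erase},
        )
        resp.raise_for_status()
        return resp.json()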
@@ -311,48 +311,31 @@ class LoginRestServlet(RestServlet):
         return result
 
 
-class BaseSSORedirectServlet(RestServlet):
-    """Common base class for /login/sso/redirect impls"""
+class SsoRedirectServlet(RestServlet):
 
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
+    def __init__(self, hs: "HomeServer"):
+        # make sure that the relevant handlers are instantiated, so that they
+        # register themselves with the main SSOHandler.
+        if hs.config.cas_enabled:
+            hs.get_cas_handler()
+        if hs.config.saml2_enabled:
+            hs.get_saml_handler()
+        if hs.config.oidc_enabled:
+            hs.get_oidc_handler()
+        self._sso_handler = hs.get_sso_handler()
+
     async def on_GET(self, request: SynapseRequest):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = await self.get_sso_url(request, client_redirect_url)
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding=None
+        )
+        sso_url = await self._sso_handler.handle_redirect_request(
+            request, client_redirect_url
+        )
+        logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        """Get the URL to redirect to, to perform SSO auth
-
-        Args:
-            request: The client request to redirect.
-            client_redirect_url: the URL that we should redirect the
-                client to when everything is done
-
-        Returns:
-            URL to redirect to
-        """
-        # to be implemented by subclasses
-        raise NotImplementedError()
-
-
-class CasRedirectServlet(BaseSSORedirectServlet):
-    def __init__(self, hs):
-        self._cas_handler = hs.get_cas_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._cas_handler.get_redirect_url(
-            {"redirectUrl": client_redirect_url}
-        ).encode("ascii")
 
 
 class CasTicketServlet(RestServlet):
     PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@@ -379,40 +362,8 @@ class CasTicketServlet(RestServlet):
         )
 
 
-class SAMLRedirectServlet(BaseSSORedirectServlet):
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._saml_handler = hs.get_saml_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._saml_handler.handle_redirect_request(client_redirect_url)
-
-
-class OIDCRedirectServlet(BaseSSORedirectServlet):
-    """Implementation for /login/sso/redirect for the OIDC login flow."""
-
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._oidc_handler = hs.get_oidc_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return await self._oidc_handler.handle_redirect_request(
-            request, client_redirect_url
-        )
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
-        CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
-    elif hs.config.saml2_enabled:
-        SAMLRedirectServlet(hs).register(http_server)
-    elif hs.config.oidc_enabled:
-        OIDCRedirectServlet(hs).register(http_server)
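The hand-rolled redirectUrl check above is replaced by parse_string(..., required=True), which rejects missing parameters uniformly. A rough sketch of what that helper does (illustrative only; the real implementation is synapse.http.servlet.parse_string, which raises a SynapseError for a missing required parameter, and encoding=None keeps the raw bytes that the SSO handler expects):

    def parse_string_sketch(args: dict, name: str, required: bool = False, default=None):
        # args maps byte names to lists of byte values, like Request.args.
        values = args.get(name.encode("ascii"))
        if values:
            return values[0]
        if required:
            raise ValueError("Missing string query parameter %r" % (name,))
        return default

    assert parse_string_sketch(
        {b"redirectUrl": [b"https://app.example"]}, "redirectUrl", required=True
    ) == b"https://app.example"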
@@ -46,7 +46,7 @@ from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
 from synapse.util import json_decoder
-from synapse.util.stringutils import random_string
+from synapse.util.stringutils import parse_and_validate_server_name, random_string
 
 if TYPE_CHECKING:
     import synapse.server
@@ -350,8 +350,6 @@ class PublicRoomListRestServlet(TransactionRestServlet):
                 # provided.
                 if server:
                     raise e
-                else:
-                    pass
 
         limit = parse_integer(request, "limit", 0)
         since_token = parse_string(request, "since", None)
@@ -362,6 +360,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server and server != self.hs.config.server_name:
+            # Ensure the server is valid.
+            try:
+                parse_and_validate_server_name(server)
+            except ValueError:
+                raise SynapseError(
+                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                )
+
             try:
                 data = await handler.get_remote_public_room_list(
                     server, limit=limit, since_token=since_token
@@ -405,6 +411,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server and server != self.hs.config.server_name:
+            # Ensure the server is valid.
+            try:
+                parse_and_validate_server_name(server)
+            except ValueError:
+                raise SynapseError(
+                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                )
+
             try:
                 data = await handler.get_remote_public_room_list(
                     server,
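parse_and_validate_server_name splits and sanity-checks a server name before it is used for a remote room-list request. A simplified sketch of its behaviour (illustrative; the real helper in synapse.util.stringutils also handles IPv6 literals):

    def parse_and_validate_server_name_sketch(server_name: str):
        # Accepts "host" or "host:port"; raises ValueError on junk like ":80".
        host, sep, port = server_name.partition(":")
        if not host or (sep and not port.isdigit()):
            raise ValueError("Invalid server name %r" % (server_name,))
        return host, int(port) if port else None

    assert parse_and_validate_server_name_sketch("matrix.org") == ("matrix.org", None)
    assert parse_and_validate_server_name_sketch("matrix.org:8448") == ("matrix.org", 8448)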
@@ -20,9 +20,6 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING
 from urllib.parse import urlparse
 
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     Codes,
@@ -31,6 +28,7 @@ from synapse.api.errors import (
     ThreepidValidationError,
 )
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -46,6 +44,10 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
 
 from ._base import client_patterns, interactive_auth_handler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -189,11 +191,7 @@ class PasswordRestServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         try:
             params, session_id = await self.auth_handler.validate_user_via_ui_auth(
-                requester,
-                request,
-                body,
-                self.hs.get_ip_from_request(request),
-                "modify your account password",
+                requester, request, body, "modify your account password",
             )
         except InteractiveAuthIncompleteError as e:
             # The user needs to provide more steps to complete auth, but
@@ -204,7 +202,9 @@ class PasswordRestServlet(RestServlet):
             if new_password:
                 password_hash = await self.auth_handler.hash(new_password)
                 await self.auth_handler.set_session_data(
-                    e.session_id, "password_hash", password_hash
+                    e.session_id,
+                    UIAuthSessionDataConstants.PASSWORD_HASH,
+                    password_hash,
                 )
             raise
         user_id = requester.user.to_string()
@@ -215,7 +215,6 @@ class PasswordRestServlet(RestServlet):
                 [[LoginType.EMAIL_IDENTITY]],
                 request,
                 body,
-                self.hs.get_ip_from_request(request),
                 "modify your account password",
             )
         except InteractiveAuthIncompleteError as e:
@@ -227,7 +226,9 @@ class PasswordRestServlet(RestServlet):
             if new_password:
                 password_hash = await self.auth_handler.hash(new_password)
                 await self.auth_handler.set_session_data(
-                    e.session_id, "password_hash", password_hash
+                    e.session_id,
+                    UIAuthSessionDataConstants.PASSWORD_HASH,
+                    password_hash,
                 )
             raise
 
@@ -260,7 +261,7 @@ class PasswordRestServlet(RestServlet):
             password_hash = await self.auth_handler.hash(new_password)
         elif session_id is not None:
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
         else:
             # UI validation was skipped, but the request did not include a new
@@ -304,19 +305,18 @@ class DeactivateAccountRestServlet(RestServlet):
         # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
-                requester.user.to_string(), erase
+                requester.user.to_string(), erase, requester
             )
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "deactivate your account",
+            requester, request, body, "deactivate your account",
        )
         result = await self._deactivate_account_handler.deactivate_account(
-            requester.user.to_string(), erase, id_server=body.get("id_server")
+            requester.user.to_string(),
+            erase,
+            requester,
+            id_server=body.get("id_server"),
         )
         if result:
             id_server_unbind_result = "success"
@@ -695,11 +695,7 @@ class ThreepidAddRestServlet(RestServlet):
         assert_valid_client_secret(client_secret)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a third-party identifier to your account",
+            requester, request, body, "add a third-party identifier to your account",
         )
 
         validation_session = await self.identity_handler.validate_threepid_session(
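The UIAuthSessionDataConstants introduced above replace the bare "password_hash" and "registered_user_id" strings scattered through these servlets. Roughly (an illustrative reconstruction, not the verbatim upstream file):

    class UIAuthSessionDataConstants:
        # The values stay identical to the old string literals, so existing
        # UI-auth sessions keep working across the change.
        PASSWORD_HASH = "password_hash"
        REGISTERED_USER_ID = "registered_user_id"

The values matter: they are the keys under which data is stashed in the UI-auth session between round trips.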
@@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_account_data_for_user(
-            user_id, account_data_type, body
-        )
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_account_data_for_user(user_id, account_data_type, body)
 
         return 200, {}
 
@@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
                 " Use /rooms/!roomId:server.name/read_markers",
             )
 
-        max_id = await self.store.add_account_data_to_room(
+        await self.handler.add_account_data_to_room(
             user_id, room_id, account_data_type, body
         )
 
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
         return 200, {}
 
     async def on_GET(self, request, user_id, room_id, account_data_type):
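The servlets above now delegate to an account data handler instead of writing to the store and poking the notifier themselves. A sketch of the shape of that handler (illustrative; the real one also replicates the write to the configured account_data writer when running on a worker, which is what lets the worker-only guard be deleted):

    class AccountDataHandlerSketch:
        def __init__(self, store, notifier):
            self._store = store
            self._notifier = notifier

        async def add_account_data_for_user(self, user_id, account_data_type, content):
            # The store write and notifier wake-up that each servlet used to
            # perform inline.
            max_id = await self._store.add_account_data_for_user(
                user_id, account_data_type, content
            )
            self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
            return max_id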
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
@@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet):
 
     PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
-
-        # SSO configuration.
-        self._cas_enabled = hs.config.cas_enabled
-        if self._cas_enabled:
-            self._cas_handler = hs.get_cas_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
-        self._oidc_enabled = hs.config.oidc_enabled
-        if self._oidc_enabled:
-            self._oidc_handler = hs.get_oidc_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
 
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
         self.success_template = hs.config.fallback_success_template
@@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet):
         elif stagetype == LoginType.SSO:
             # Display a confirmation page which prompts the user to
             # re-authenticate with their SSO provider.
-            if self._cas_enabled:
-                # Generate a request to CAS that redirects back to an endpoint
-                # to verify the successful authentication.
-                sso_redirect_url = self._cas_handler.get_redirect_url(
-                    {"session": session},
-                )
-
-            elif self._saml_enabled:
-                # Some SAML identity providers (e.g. Google) require a
-                # RelayState parameter on requests. It is not necessary here, so
-                # pass in a dummy redirect URL (which will never get used).
-                client_redirect_url = b"unused"
-                sso_redirect_url = self._saml_handler.handle_redirect_request(
-                    client_redirect_url, session
-                )
-
-            elif self._oidc_enabled:
-                client_redirect_url = b""
-                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
-                    request, client_redirect_url, session
-                )
-
-            else:
-                raise SynapseError(400, "Homeserver not configured for SSO.")
-
-            html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
-
+            html = await self.auth_handler.start_sso_ui_auth(request, session)
         else:
             raise SynapseError(404, "Unknown auth stage type")
@@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet):
         authdict = {"response": response, "session": session}
 
         success = await self.auth_handler.add_oob_auth(
-            LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
+            LoginType.RECAPTCHA, authdict, request.getClientIP()
         )
 
         if success:
@@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet):
         authdict = {"session": session}
 
         success = await self.auth_handler.add_oob_auth(
-            LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
+            LoginType.TERMS, authdict, request.getClientIP()
         )
 
         if success:
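request.getClientIP() is the standard Twisted accessor that replaces the homeserver helper here; it returns the address of the direct TCP peer as a string. A self-contained illustration (hypothetical resource, not from the commit):

    from twisted.web import resource

    class WhoAmISketch(resource.Resource):
        isLeaf = True

        def render_GET(self, request) -> bytes:
            # Note: behind a reverse proxy this is the proxy's address unless
            # X-Forwarded-For handling is enabled on the site.
            return request.getClientIP().encode("ascii")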
@@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
         assert_params_in_dict(body, ["devices"])
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove device(s) from your account",
+            requester, request, body, "remove device(s) from your account",
         )
 
         await self.device_handler.delete_devices(
@@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
             raise
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove a device from your account",
+            requester, request, body, "remove a device from your account",
         )
 
         await self.device_handler.delete_device(requester.user.to_string(), device_id)
@@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
         body = parse_json_object_from_request(request)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a device signing key to your account",
+            requester, request, body, "add a device signing key to your account",
         )
 
         result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
@@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
 from synapse.config.registration import RegistrationConfig
 from synapse.config.server import is_threepid_reserved
 from synapse.handlers.auth import AuthHandler
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -353,7 +354,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
                 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
             )
 
-        ip = self.hs.get_ip_from_request(request)
+        ip = request.getClientIP()
         with self.ratelimiter.ratelimit(ip) as wait_deferred:
             await wait_deferred
 
@@ -494,11 +495,11 @@ class RegisterRestServlet(RestServlet):
         # user here. We carry on and go through the auth checks though,
         # for paranoia.
         registered_user_id = await self.auth_handler.get_session_data(
-            session_id, "registered_user_id", None
+            session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
         )
         # Extract the previously-hashed password from the session.
         password_hash = await self.auth_handler.get_session_data(
-            session_id, "password_hash", None
+            session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
         )
 
         # Ensure that the username is valid.
@@ -513,11 +514,7 @@ class RegisterRestServlet(RestServlet):
         # not this will raise a user-interactive auth error.
         try:
             auth_result, params, session_id = await self.auth_handler.check_ui_auth(
-                self._registration_flows,
-                request,
-                body,
-                self.hs.get_ip_from_request(request),
-                "register a new account",
+                self._registration_flows, request, body, "register a new account",
             )
         except InteractiveAuthIncompleteError as e:
             # The user needs to provide more steps to complete auth.
@@ -532,7 +529,9 @@ class RegisterRestServlet(RestServlet):
             if not password_hash and password:
                 password_hash = await self.auth_handler.hash(password)
                 await self.auth_handler.set_session_data(
-                    e.session_id, "password_hash", password_hash
+                    e.session_id,
+                    UIAuthSessionDataConstants.PASSWORD_HASH,
+                    password_hash,
                 )
             raise
 
@@ -633,7 +632,9 @@ class RegisterRestServlet(RestServlet):
         # Remember that the user account has been registered (and the user
         # ID it was registered with, since it might not have been specified).
         await self.auth_handler.set_session_data(
-            session_id, "registered_user_id", registered_user_id
+            session_id,
+            UIAuthSessionDataConstants.REGISTERED_USER_ID,
+            registered_user_id,
         )
 
         registered = True
@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
     def __init__(self, hs):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, tag):
         requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_tag_to_room(user_id, room_id, tag, body)
 
         return 200, {}
 
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add tags for other users.")
 
-        max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.remove_tag_from_room(user_id, room_id, tag)
 
         return 200, {}
 
@ -1,6 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
# Copyright 2019 New Vector Ltd
|
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -17,10 +17,11 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
from typing import Awaitable
|
from typing import Awaitable, Dict, Generator, List, Optional, Tuple
|
||||||
|
|
||||||
from twisted.internet.interfaces import IConsumer
|
from twisted.internet.interfaces import IConsumer
|
||||||
from twisted.protocols.basic import FileSender
|
from twisted.protocols.basic import FileSender
|
||||||
|
from twisted.web.http import Request
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError, cs_error
|
from synapse.api.errors import Codes, SynapseError, cs_error
|
||||||
from synapse.http.server import finish_request, respond_with_json
|
from synapse.http.server import finish_request, respond_with_json
|
||||||
|
@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def parse_media_id(request):
|
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
|
||||||
try:
|
try:
|
||||||
# This allows users to append e.g. /test.png to the URL. Useful for
|
# This allows users to append e.g. /test.png to the URL. Useful for
|
||||||
# clients that parse the URL to see content type.
|
# clients that parse the URL to see content type.
|
||||||
|
@ -69,7 +70,7 @@ def parse_media_id(request):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def respond_404(request):
|
def respond_404(request: Request) -> None:
|
||||||
respond_with_json(
|
respond_with_json(
|
||||||
request,
|
request,
|
||||||
404,
|
404,
|
||||||
|
@ -79,8 +80,12 @@ def respond_404(request):
|
||||||
|
|
||||||
|
|
||||||
async def respond_with_file(
|
async def respond_with_file(
|
||||||
request, media_type, file_path, file_size=None, upload_name=None
|
request: Request,
|
||||||
):
|
media_type: str,
|
||||||
|
file_path: str,
|
||||||
|
file_size: Optional[int] = None,
|
||||||
|
upload_name: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
logger.debug("Responding with %r", file_path)
|
logger.debug("Responding with %r", file_path)
|
||||||
|
|
||||||
if os.path.isfile(file_path):
|
if os.path.isfile(file_path):
|
||||||
|
@ -98,15 +103,20 @@ async def respond_with_file(
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
|
|
||||||
|
|
||||||
def add_file_headers(request, media_type, file_size, upload_name):
|
def add_file_headers(
|
||||||
|
request: Request,
|
||||||
|
media_type: str,
|
||||||
|
file_size: Optional[int],
|
||||||
|
upload_name: Optional[str],
|
||||||
|
) -> None:
|
||||||
"""Adds the correct response headers in preparation for responding with the
|
"""Adds the correct response headers in preparation for responding with the
|
||||||
media.
|
media.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request (twisted.web.http.Request)
|
request
|
||||||
media_type (str): The media/content type.
|
media_type: The media/content type.
|
||||||
file_size (int): Size in bytes of the media, if known.
|
file_size: Size in bytes of the media, if known.
|
||||||
upload_name (str): The name of the requested file, if any.
|
upload_name: The name of the requested file, if any.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _quote(x):
|
def _quote(x):
|
||||||
|
@ -153,6 +163,7 @@ def add_file_headers(request, media_type, file_size, upload_name):
|
||||||
# select private. don't bother setting Expires as all our
|
# select private. don't bother setting Expires as all our
|
||||||
# clients are smart enough to be happy with Cache-Control
|
# clients are smart enough to be happy with Cache-Control
|
||||||
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
|
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
|
||||||
|
if file_size is not None:
|
||||||
request.setHeader(b"Content-Length", b"%d" % (file_size,))
|
request.setHeader(b"Content-Length", b"%d" % (file_size,))
|
||||||
|
|
||||||
# Tell web crawlers to not index, archive, or follow links in media. This
|
# Tell web crawlers to not index, archive, or follow links in media. This
|
||||||
|
@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
|
||||||
}
|
 }


-def _can_encode_filename_as_token(x):
+def _can_encode_filename_as_token(x: str) -> bool:
     for c in x:
         # from RFC2616:
         #
@@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x):


 async def respond_with_responder(
-    request, responder, media_type, file_size, upload_name=None
-):
+    request: Request,
+    responder: "Optional[Responder]",
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str] = None,
+) -> None:
     """Responds to the request with given responder. If responder is None then
     returns 404.

     Args:
-        request (twisted.web.http.Request)
-        responder (Responder|None)
-        media_type (str): The media/content type.
-        file_size (int|None): Size in bytes of the media. If not known it should be None
-        upload_name (str|None): The name of the requested file, if any.
+        request
+        responder
+        media_type: The media/content type.
+        file_size: Size in bytes of the media. If not known it should be None
+        upload_name: The name of the requested file, if any.
     """
     if request._disconnected:
         logger.warning(
@@ -308,22 +323,22 @@ class FileInfo:
         self.thumbnail_type = thumbnail_type


-def get_filename_from_headers(headers):
+def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
     """
     Get the filename of the downloaded file by inspecting the
     Content-Disposition HTTP header.

     Args:
-        headers (dict[bytes, list[bytes]]): The HTTP request headers.
+        headers: The HTTP request headers.

     Returns:
-        A Unicode string of the filename, or None.
+        The filename, or None.
     """
     content_disposition = headers.get(b"Content-Disposition", [b""])

     # No header, bail out.
     if not content_disposition[0]:
-        return
+        return None

     _, params = _parse_header(content_disposition[0])

@@ -356,17 +371,16 @@ def get_filename_from_headers(headers):
     return upload_name


-def _parse_header(line):
+def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
     """Parse a Content-type like header.

     Cargo-culted from `cgi`, but works on bytes rather than strings.

     Args:
-        line (bytes): header to be parsed
+        line: header to be parsed

     Returns:
-        Tuple[bytes, dict[bytes, bytes]]:
-            the main content-type, followed by the parameter dictionary
+        The main content-type, followed by the parameter dictionary
     """
     parts = _parseparam(b";" + line)
     key = next(parts)
@@ -386,16 +400,16 @@ def _parse_header(line):
     return key, pdict


-def _parseparam(s):
+def _parseparam(s: bytes) -> Generator[bytes, None, None]:
     """Generator which splits the input on ;, respecting double-quoted sequences

     Cargo-culted from `cgi`, but works on bytes rather than strings.

     Args:
-        s (bytes): header to be parsed
+        s: header to be parsed

     Returns:
-        Iterable[bytes]: the split input
+        The split input
     """
     while s[:1] == b";":
         s = s[1:]
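
Note: the typed `_parse_header`/`_parseparam` helpers above mirror the standard library's `cgi.parse_header`, but operate on bytes. A minimal usage sketch (the header value and expected results are illustrative assumptions, not part of the diff):

    # Assuming the module above is imported; _parse_header splits the main
    # token from its parameters and strips double quotes around values.
    header_value = b'inline; filename="cat picture.jpg"'
    key, params = _parse_header(header_value)
    assert key == b"inline"
    assert params == {b"filename": b"cat picture.jpg"}
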
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2018 Will Hunt <will@half-shot.uk>
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,22 +15,29 @@
 # limitations under the License.
 #

+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
 from synapse.http.server import DirectServeJsonResource, respond_with_json

+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+

 class MediaConfigResource(DirectServeJsonResource):
     isLeaf = True

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         config = hs.get_config()
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
         self.limits_dict = {"m.upload.size": config.max_upload_size}

-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         await self.auth.get_user_by_req(request)
         respond_with_json(request, 200, self.limits_dict, send_cors=True)

-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
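
Note: the `TYPE_CHECKING` guard introduced above is what lets `__init__` annotate `hs` as `"HomeServer"` without importing `synapse.app.homeserver` at runtime, which would create an import cycle. A minimal sketch of the pattern, using placeholder module and class names:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by type checkers such as mypy, never at runtime,
        # so no circular import can occur.
        from big_app_module import Dependency  # hypothetical module

    class Consumer:
        def __init__(self, dep: "Dependency"):  # string annotation, resolved lazily
            self.dep = dep
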
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,24 +14,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request

-import synapse.http.servlet
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
+from synapse.http.servlet import parse_boolean

 from ._base import parse_media_id, respond_404

+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)


 class DownloadResource(DirectServeJsonResource):
     isLeaf = True

-    def __init__(self, hs, media_repo):
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
         self.media_repo = media_repo
         self.server_name = hs.hostname

-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         request.setHeader(
             b"Content-Security-Policy",
@@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
         if server_name == self.server_name:
             await self.media_repo.get_local_media(request, media_id, name)
         else:
-            allow_remote = synapse.http.servlet.parse_boolean(
-                request, "allow_remote", default=True
-            )
+            allow_remote = parse_boolean(request, "allow_remote", default=True)
             if not allow_remote:
                 logger.info(
                     "Rejecting request for remote media %s/%s due to allow_remote",
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,11 +17,12 @@
 import functools
 import os
 import re
+from typing import Callable, List

 NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")


-def _wrap_in_base_path(func):
+def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
     """Takes a function that returns a relative path and turns it into an
     absolute path based on the location of the primary media store
     """
@@ -41,12 +43,18 @@ class MediaFilePaths:
     to write to the backup media store (when one is configured)
     """

-    def __init__(self, primary_base_path):
+    def __init__(self, primary_base_path: str):
         self.base_path = primary_base_path

     def default_thumbnail_rel(
-        self, default_top_level, default_sub_type, width, height, content_type, method
-    ):
+        self,
+        default_top_level: str,
+        default_sub_type: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -55,12 +63,14 @@ class MediaFilePaths:

     default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)

-    def local_media_filepath_rel(self, media_id):
+    def local_media_filepath_rel(self, media_id: str) -> str:
         return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])

     local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)

-    def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
+    def local_media_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -86,7 +96,7 @@ class MediaFilePaths:
             media_id[4:],
         )

-    def remote_media_filepath_rel(self, server_name, file_id):
+    def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
         )
@@ -94,8 +104,14 @@ class MediaFilePaths:
     remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)

     def remote_media_thumbnail_rel(
-        self, server_name, file_id, width, height, content_type, method
-    ):
+        self,
+        server_name: str,
+        file_id: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -113,7 +129,7 @@ class MediaFilePaths:
     # Should be removed after some time, when most of the thumbnails are stored
     # using the new path.
     def remote_media_thumbnail_rel_legacy(
-        self, server_name, file_id, width, height, content_type
+        self, server_name: str, file_id: str, width: int, height: int, content_type: str
     ):
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@@ -126,7 +142,7 @@ class MediaFilePaths:
             file_name,
         )

-    def remote_media_thumbnail_dir(self, server_name, file_id):
+    def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             self.base_path,
             "remote_thumbnail",
@@ -136,7 +152,7 @@ class MediaFilePaths:
             file_id[4:],
         )

-    def url_cache_filepath_rel(self, media_id):
+    def url_cache_filepath_rel(self, media_id: str) -> str:
         if NEW_FORMAT_ID_RE.match(media_id):
             # Media id is of the form <DATE><RANDOM_STRING>
             # E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -146,7 +162,7 @@ class MediaFilePaths:

     url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)

-    def url_cache_filepath_dirs_to_delete(self, media_id):
+    def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id file"
         if NEW_FORMAT_ID_RE.match(media_id):
             return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@@ -156,7 +172,9 @@ class MediaFilePaths:
             os.path.join(self.base_path, "url_cache", media_id[0:2]),
         ]

-    def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
+    def url_cache_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf

@@ -178,7 +196,7 @@ class MediaFilePaths:

     url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)

-    def url_cache_thumbnail_directory(self, media_id):
+    def url_cache_thumbnail_directory(self, media_id: str) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf

@@ -195,7 +213,7 @@ class MediaFilePaths:
             media_id[4:],
         )

-    def url_cache_thumbnail_dirs_to_delete(self, media_id):
+    def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id thumbnails"
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
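
Note: each `*_rel` method above returns a path relative to the media store, and `_wrap_in_base_path` derives the absolute variant from it. A self-contained sketch of the mechanism (simplified reimplementation, POSIX paths assumed):

    import functools
    import os

    def _wrap_in_base_path(func):
        # Wrap a *_rel method so its relative result is joined onto base_path.
        @functools.wraps(func)
        def _wrapped(self, *args, **kwargs):
            return os.path.join(self.base_path, func(self, *args, **kwargs))
        return _wrapped

    class Paths:
        def __init__(self, base_path: str):
            self.base_path = base_path

        def local_media_filepath_rel(self, media_id: str) -> str:
            return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])

        local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)

    p = Paths("/data/media_store")
    assert p.local_media_filepath("abcdefgh") == "/data/media_store/local_content/ab/cd/efgh"
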
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import errno
 import logging
 import os
 import shutil
-from typing import IO, Dict, List, Optional, Tuple
+from io import BytesIO
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple

 import twisted.internet.error
 import twisted.web.http
@@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
 from .thumbnailer import Thumbnailer, ThumbnailError
 from .upload_resource import UploadResource

+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)


@@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000


 class MediaRepository:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.client = hs.get_federation_http_client()
@@ -73,16 +76,16 @@ class MediaRepository:
         self.max_upload_size = hs.config.max_upload_size
         self.max_image_pixels = hs.config.max_image_pixels

-        self.primary_base_path = hs.config.media_store_path
-        self.filepaths = MediaFilePaths(self.primary_base_path)
+        self.primary_base_path = hs.config.media_store_path  # type: str
+        self.filepaths = MediaFilePaths(self.primary_base_path)  # type: MediaFilePaths

         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.thumbnail_requirements = hs.config.thumbnail_requirements

         self.remote_media_linearizer = Linearizer(name="media_remote")

-        self.recently_accessed_remotes = set()
-        self.recently_accessed_locals = set()
+        self.recently_accessed_remotes = set()  # type: Set[Tuple[str, str]]
+        self.recently_accessed_locals = set()  # type: Set[str]

         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -113,7 +116,7 @@ class MediaRepository:
             "update_recently_accessed_media", self._update_recently_accessed
         )

-    async def _update_recently_accessed(self):
+    async def _update_recently_accessed(self) -> None:
         remote_media = self.recently_accessed_remotes
         self.recently_accessed_remotes = set()
@@ -124,12 +127,12 @@ class MediaRepository:
             local_media, remote_media, self.clock.time_msec()
         )

-    def mark_recently_accessed(self, server_name, media_id):
+    def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
         """Mark the given media as recently accessed.

         Args:
-            server_name (str|None): Origin server of media, or None if local
-            media_id (str): The media ID of the content
+            server_name: Origin server of media, or None if local
+            media_id: The media ID of the content
         """
         if server_name:
             self.recently_accessed_remotes.add((server_name, media_id))
@@ -459,7 +462,14 @@ class MediaRepository:
     def _get_thumbnail_requirements(self, media_type):
         return self.thumbnail_requirements.get(media_type, ())

-    def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
+    def _generate_thumbnail(
+        self,
+        thumbnailer: Thumbnailer,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+    ) -> Optional[BytesIO]:
         m_width = thumbnailer.width
         m_height = thumbnailer.height
@@ -470,22 +480,20 @@ class MediaRepository:
                 m_height,
                 self.max_image_pixels,
             )
-            return
+            return None

         if thumbnailer.transpose_method is not None:
             m_width, m_height = thumbnailer.transpose()

         if t_method == "crop":
-            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
+            return thumbnailer.crop(t_width, t_height, t_type)
         elif t_method == "scale":
             t_width, t_height = thumbnailer.aspect(t_width, t_height)
             t_width = min(m_width, t_width)
             t_height = min(m_height, t_height)
-            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
-        else:
-            t_byte_source = None
+            return thumbnailer.scale(t_width, t_height, t_type)

-        return t_byte_source
+        return None

     async def generate_local_exact_thumbnail(
         self,
@@ -776,7 +784,7 @@ class MediaRepository:

         return {"width": m_width, "height": m_height}

-    async def delete_old_remote_media(self, before_ts):
+    async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
         old_media = await self.store.get_remote_media_before(before_ts)

         deleted = 0
@@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
     within a given rectangle.
     """

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         # If we're not configured to use it, raise if we somehow got here.
         if not hs.config.can_load_media_repo:
             raise ConfigError("Synapse is not configured to use a media repo.")
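
Note: the new attribute annotations in `MediaRepository.__init__` use `# type:` comments, the comment-style spelling of PEP 526 variable annotations. To a type checker the two forms below are equivalent (a sketch, not code from the diff):

    from typing import Set, Tuple

    class Example:
        def __init__(self) -> None:
            # Comment-style annotation, as used in the diff above:
            self.recently_accessed_remotes = set()  # type: Set[Tuple[str, str]]
            # Equivalent inline (PEP 526) annotation:
            self.recently_accessed_locals: Set[str] = set()
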
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vecotr Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ import os
 import shutil
 from typing import IO, TYPE_CHECKING, Any, Optional, Sequence

+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender

 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@@ -270,7 +272,7 @@ class MediaStorage:
         return self.filepaths.local_media_filepath_rel(file_info.file_id)


-def _write_file_synchronously(source, dest):
+def _write_file_synchronously(source: IO, dest: IO) -> None:
     """Write `source` to the file like `dest` synchronously. Should be called
     from a thread.
@@ -286,14 +288,14 @@ class FileResponder(Responder):
     """Wraps an open file that can be sent to a request.

     Args:
-        open_file (file): A file like object to be streamed ot the client,
+        open_file: A file like object to be streamed ot the client,
             is closed when finished streaming.
     """

-    def __init__(self, open_file):
+    def __init__(self, open_file: IO):
         self.open_file = open_file

-    def write_to_consumer(self, consumer):
+    def write_to_consumer(self, consumer: IConsumer) -> Deferred:
         return make_deferred_yieldable(
             FileSender().beginFileTransfer(self.open_file, consumer)
         )
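
Note: `FileResponder.write_to_consumer` now declares that it returns a `Deferred`. Twisted's `FileSender` pushes the open file into any `IConsumer` and fires that Deferred with the last byte sent once the transfer completes. A rough sketch of the flow (assumes a running reactor and a consumer such as an HTTP request; not the exact Synapse code):

    from twisted.protocols.basic import FileSender

    def stream_file(open_file, consumer):
        # beginFileTransfer registers FileSender as the consumer's producer
        # and returns a Deferred that fires when the whole file is written.
        d = FileSender().beginFileTransfer(open_file, consumer)

        def _done(last_byte):
            open_file.close()  # matches "is closed when finished streaming"
            return last_byte

        return d.addCallback(_done)
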
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import datetime
 import errno
 import fnmatch
@@ -23,12 +23,13 @@ import re
 import shutil
 import sys
 import traceback
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
 from urllib import parse as urlparse

 import attr

 from twisted.internet.error import DNSLookupError
+from twisted.web.http import Request

 from synapse.api.errors import Codes, SynapseError
 from synapse.http.client import SimpleHttpClient
@@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string

 from ._base import FileInfo

+if TYPE_CHECKING:
+    from lxml import etree
+
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)

 _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@@ -107,7 +115,12 @@ class OEmbedError(Exception):
 class PreviewUrlResource(DirectServeJsonResource):
     isLeaf = True

-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()

         self.auth = hs.get_auth()
@@ -155,11 +168,11 @@ class PreviewUrlResource(DirectServeJsonResource):
             self._start_expire_url_cache_data, 10 * 1000
         )

-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         request.setHeader(b"Allow", b"OPTIONS, GET")
         respond_with_json(request, 200, {}, send_cors=True)

-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:

         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)
@@ -439,7 +452,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
             raise OEmbedError() from e

-    async def _download_url(self, url: str, user):
+    async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -569,7 +582,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             "expire_url_cache_data", self._expire_url_cache_data
         )

-    async def _expire_url_cache_data(self):
+    async def _expire_url_cache_data(self) -> None:
         """Clean up expired url cache content, media and thumbnails.
         """
         # TODO: Delete from backup media store
@@ -665,7 +678,9 @@ class PreviewUrlResource(DirectServeJsonResource):
         logger.debug("No media removed from url cache")


-def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
+def decode_and_calc_og(
+    body: bytes, media_uri: str, request_encoding: Optional[str] = None
+) -> Dict[str, Optional[str]]:
     # If there's no body, nothing useful is going to be found.
     if not body:
         return {}
@@ -686,7 +701,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]
     return og


-def _calc_og(tree, media_uri):
+def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
     # suck our tree into lxml and define our OG response.

     # if we see any image URLs in the OG response, then spider them
@@ -790,7 +805,9 @@ def _calc_og(tree, media_uri):
             for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
         )
         og["og:description"] = summarize_paragraphs(text_nodes)
-    else:
+    elif og["og:description"]:
+        # This must be a non-empty string at this point.
+        assert isinstance(og["og:description"], str)
         og["og:description"] = summarize_paragraphs([og["og:description"]])

     # TODO: delete the url downloads to stop diskfilling,
@@ -798,7 +815,9 @@ def _calc_og(tree, media_uri):
     return og


-def _iterate_over_text(tree, *tags_to_ignore):
+def _iterate_over_text(
+    tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
     """Iterate over the tree returning text nodes in a depth first fashion,
     skipping text nodes inside certain tags.
     """
@@ -832,32 +851,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
     )


-def _rebase_url(url, base):
-    base = list(urlparse.urlparse(base))
-    url = list(urlparse.urlparse(url))
-    if not url[0]:  # fix up schema
-        url[0] = base[0] or "http"
-    if not url[1]:  # fix up hostname
-        url[1] = base[1]
-    if not url[2].startswith("/"):
-        url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
-    return urlparse.urlunparse(url)
+def _rebase_url(url: str, base: str) -> str:
+    base_parts = list(urlparse.urlparse(base))
+    url_parts = list(urlparse.urlparse(url))
+    if not url_parts[0]:  # fix up schema
+        url_parts[0] = base_parts[0] or "http"
+    if not url_parts[1]:  # fix up hostname
+        url_parts[1] = base_parts[1]
+    if not url_parts[2].startswith("/"):
+        url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+    return urlparse.urlunparse(url_parts)


-def _is_media(content_type):
-    if content_type.lower().startswith("image/"):
-        return True
+def _is_media(content_type: str) -> bool:
+    return content_type.lower().startswith("image/")


-def _is_html(content_type):
+def _is_html(content_type: str) -> bool:
     content_type = content_type.lower()
-    if content_type.startswith("text/html") or content_type.startswith(
+    return content_type.startswith("text/html") or content_type.startswith(
         "application/xhtml"
-    ):
-        return True
+    )


-def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
+def summarize_paragraphs(
+    text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
     # Try to get a summary of between 200 and 500 words, respecting
     # first paragraph and then word boundaries.
     # TODO: Respect sentences?
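
Note: `_rebase_url` is behaviourally unchanged; the rename to `base_parts`/`url_parts` just stops the arguments being shadowed by their parsed forms. A standalone sketch of what it does, with an assumed example URL:

    import re
    from urllib import parse as urlparse

    def _rebase_url(url: str, base: str) -> str:
        # Fill in scheme, host and directory from `base` when `url` is relative.
        base_parts = list(urlparse.urlparse(base))
        url_parts = list(urlparse.urlparse(url))
        if not url_parts[0]:  # fix up schema
            url_parts[0] = base_parts[0] or "http"
        if not url_parts[1]:  # fix up hostname
            url_parts[1] = base_parts[1]
        if not url_parts[2].startswith("/"):
            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
        return urlparse.urlunparse(url_parts)

    assert (
        _rebase_url("images/logo.png", "https://example.com/blog/post.html")
        == "https://example.com/blog/images/logo.png"
    )
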
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vecotr Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import abc
 import logging
 import os
 import shutil
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 from synapse.config._base import Config
 from synapse.logging.context import defer_to_thread, run_in_background
@@ -27,13 +28,17 @@ from .media_storage import FileResponder

 logger = logging.getLogger(__name__)

+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer

-class StorageProvider:
+
+class StorageProvider(metaclass=abc.ABCMeta):
     """A storage provider is a service that can store uploaded media and
     retrieve them.
     """

-    async def store_file(self, path: str, file_info: FileInfo):
+    @abc.abstractmethod
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """Store the file described by file_info. The actual contents can be
         retrieved by reading the file in file_info.upload_path.

         Args:
@@ -42,6 +47,7 @@ class StorageProvider:
             file_info: The metadata of the file.
         """

+    @abc.abstractmethod
     async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """Attempt to fetch the file described by file_info and stream it
         into writer.
@@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
         self.store_synchronous = store_synchronous
         self.store_remote = store_remote

-    def __str__(self):
+    def __str__(self) -> str:
         return "StorageProviderWrapper[%s]" % (self.backend,)

-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         if not file_info.server_name and not self.store_local:
             return None

@@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider):
         if self.store_synchronous:
             # store_file is supposed to return an Awaitable, but guard
             # against improper implementations.
-            return await maybe_awaitable(self.backend.store_file(path, file_info))
+            await maybe_awaitable(self.backend.store_file(path, file_info))  # type: ignore
         else:
             # TODO: Handle errors.
             async def store():
@@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider):
                     logger.exception("Error storing file")

             run_in_background(store)
-            return None

-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         # store_file is supposed to return an Awaitable, but guard
         # against improper implementations.
         return await maybe_awaitable(self.backend.fetch(path, file_info))
@@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider):
     """A storage provider that stores files in a directory on a filesystem.

     Args:
-        hs (HomeServer)
+        hs
         config: The config returned by `parse_config`.
     """

-    def __init__(self, hs, config):
+    def __init__(self, hs: "HomeServer", config: str):
         self.hs = hs
         self.cache_directory = hs.config.media_store_path
         self.base_directory = config
@@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
     def __str__(self):
         return "FileStorageProviderBackend[%s]" % (self.base_directory,)

-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """See StorageProvider.store_file"""

         primary_fname = os.path.join(self.cache_directory, path)
@@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
         if not os.path.exists(dirname):
             os.makedirs(dirname)

-        return await defer_to_thread(
+        await defer_to_thread(
             self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
         )

-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """See StorageProvider.fetch"""

         backup_fname = os.path.join(self.base_directory, path)
         if os.path.isfile(backup_fname):
             return FileResponder(open(backup_fname, "rb"))

+        return None
+
     @staticmethod
-    def parse_config(config):
+    def parse_config(config: dict) -> str:
         """Called on startup to parse config supplied. This should parse
         the config and raise if there is a problem.
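
Note: making `StorageProvider` an `abc.ABCMeta` class means a provider that forgets to implement `store_file` or `fetch` now fails loudly at instantiation rather than at first use. A minimal sketch with toy names:

    import abc

    class Provider(metaclass=abc.ABCMeta):
        @abc.abstractmethod
        async def store_file(self, path, file_info) -> None:
            ...

        @abc.abstractmethod
        async def fetch(self, path, file_info):
            ...

    class Incomplete(Provider):
        async def store_file(self, path, file_info) -> None:
            pass  # fetch() deliberately not implemented

    try:
        Incomplete()
    except TypeError as e:
        print(e)  # e.g. "Can't instantiate abstract class Incomplete ..."
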
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,10 +16,14 @@


 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request

 from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.rest.media.v1.media_storage import MediaStorage

 from ._base import (
     FileInfo,
@@ -28,13 +33,22 @@ from ._base import (
     respond_with_responder,
 )

+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)


 class ThumbnailResource(DirectServeJsonResource):
     isLeaf = True

-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()

         self.store = hs.get_datastore()
@@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname

-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         server_name, media_id, _ = parse_media_id(request)
         width = parse_integer(request, "width", required=True)
@@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
         self.media_repo.mark_recently_accessed(server_name, media_id)

     async def _respond_local_thumbnail(
-        self, request, media_id, width, height, method, m_type
-    ):
+        self,
+        request: Request,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)

         if not media_info:
@@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource):

     async def _select_or_generate_local_thumbnail(
         self,
-        request,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)

         if not media_info:
@@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource):

     async def _select_or_generate_remote_thumbnail(
         self,
-        request,
-        server_name,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        server_name: str,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.media_repo.get_remote_media_info(server_name, media_id)

         thumbnail_infos = await self.store.get_remote_media_thumbnails(
@@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource):
         raise SynapseError(400, "Failed to generate thumbnail.")

     async def _respond_remote_thumbnail(
-        self, request, server_name, media_id, width, height, method, m_type
-    ):
+        self,
+        request: Request,
+        server_name: str,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+    ) -> None:
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead of
         # downloading the remote file and generating our own thumbnails.
@@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource):

     def _select_thumbnail(
         self,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
         thumbnail_infos,
-    ):
+    ) -> dict:
         d_w = desired_width
         d_h = desired_height
Some files were not shown because too many files have changed in this diff.