Merge remote-tracking branch 'upstream/release-v1.24.0'

This commit is contained in:
Tulir Asokan 2020-12-02 16:30:50 +02:00
commit 6e2f942da1
152 changed files with 4450 additions and 2462 deletions

View file

@ -6,7 +6,7 @@
set -ex
apt-get update
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
export LANG="C.UTF-8"

View file

@ -1,3 +1,75 @@
Synapse 1.24.0rc1 (2020-12-02)
==============================
Features
--------
- Add admin API for logging in as a user. ([\#8617](https://github.com/matrix-org/synapse/issues/8617))
- Allow specification of the SAML IdP if the metadata returns multiple IdPs. ([\#8630](https://github.com/matrix-org/synapse/issues/8630))
- Add support for re-trying generation of a localpart for OpenID Connect mapping providers. ([\#8801](https://github.com/matrix-org/synapse/issues/8801), [\#8855](https://github.com/matrix-org/synapse/issues/8855))
- Allow the `Date` header through CORS. Contributed by Nicolas Chamo. ([\#8804](https://github.com/matrix-org/synapse/issues/8804))
- Add a config option, `push.group_by_unread_count`, which controls whether unread message counts in push notifications are defined as "the number of rooms with unread messages" or "total unread messages". ([\#8820](https://github.com/matrix-org/synapse/issues/8820))
- Add `force_purge` option to delete-room admin api. ([\#8843](https://github.com/matrix-org/synapse/issues/8843))
Bugfixes
--------
- Fix a bug where appservices may be sent an excessive amount of read receipts and presence. Broke in v1.22.0. ([\#8744](https://github.com/matrix-org/synapse/issues/8744))
- Fix a bug in some federation APIs which could lead to unexpected behaviour if different parameters were set in the URI and the request body. ([\#8776](https://github.com/matrix-org/synapse/issues/8776))
- Fix a bug where synctl could spawn duplicate copies of a worker. Contributed by Waylon Cude. ([\#8798](https://github.com/matrix-org/synapse/issues/8798))
- Allow per-room profiles to be used for the server notice user. ([\#8799](https://github.com/matrix-org/synapse/issues/8799))
- Fix a bug where logging could break after a call to SIGHUP. ([\#8817](https://github.com/matrix-org/synapse/issues/8817))
- Fix `register_new_matrix_user` failing with "Bad Request" when trailing slash is included in server URL. Contributed by @angdraug. ([\#8823](https://github.com/matrix-org/synapse/issues/8823))
- Fix a minor long-standing bug in login, where we would offer the `password` login type if a custom auth provider supported it, even if password login was disabled. ([\#8835](https://github.com/matrix-org/synapse/issues/8835))
- Fix a long-standing bug which caused Synapse to require unspecified parameters during user-interactive authentication. ([\#8848](https://github.com/matrix-org/synapse/issues/8848))
- Fix a bug introduced in v1.20.0 where the user-agent and IP address reported during user registration for CAS, OpenID Connect, and SAML were of the wrong form. ([\#8784](https://github.com/matrix-org/synapse/issues/8784))
Improved Documentation
----------------------
- Clarify the usecase for a msisdn delegate. Contributed by Adrian Wannenmacher. ([\#8734](https://github.com/matrix-org/synapse/issues/8734))
- Remove extraneous comma from JSON example in User Admin API docs. ([\#8771](https://github.com/matrix-org/synapse/issues/8771))
- Update `turn-howto.md` with troubleshooting notes. ([\#8779](https://github.com/matrix-org/synapse/issues/8779))
- Fix the example on how to set the `Content-Type` header in nginx for the Client Well-Known URI. ([\#8793](https://github.com/matrix-org/synapse/issues/8793))
- Improve the documentation for the admin API to list all media in a room with respect to encrypted events. ([\#8795](https://github.com/matrix-org/synapse/issues/8795))
- Update the formatting of the `push` section of the homeserver config file to better align with the [code style guidelines](https://github.com/matrix-org/synapse/blob/develop/docs/code_style.md#configuration-file-format). ([\#8818](https://github.com/matrix-org/synapse/issues/8818))
- Improve documentation how to configure prometheus for workers. ([\#8822](https://github.com/matrix-org/synapse/issues/8822))
- Update example prometheus console. ([\#8824](https://github.com/matrix-org/synapse/issues/8824))
Deprecations and Removals
-------------------------
- Remove old `/_matrix/client/*/admin` endpoints which were deprecated since Synapse 1.20.0. ([\#8785](https://github.com/matrix-org/synapse/issues/8785))
- Disable pretty printing JSON responses for curl. Users who want pretty-printed output should use [jq](https://stedolan.github.io/jq/) in combination with curl. Contributed by @tulir. ([\#8833](https://github.com/matrix-org/synapse/issues/8833))
Internal Changes
----------------
- Simplify the way the `HomeServer` object caches its internal attributes. ([\#8565](https://github.com/matrix-org/synapse/issues/8565), [\#8851](https://github.com/matrix-org/synapse/issues/8851))
- Add an example and documentation for clock skew to the SAML2 sample configuration to allow for clock/time difference between the homserver and IdP. Contributed by @localguru. ([\#8731](https://github.com/matrix-org/synapse/issues/8731))
- Generalise `RoomMemberHandler._locally_reject_invite` to apply to more flows than just invite. ([\#8751](https://github.com/matrix-org/synapse/issues/8751))
- Generalise `RoomStore.maybe_store_room_on_invite` to handle other, non-invite membership events. ([\#8754](https://github.com/matrix-org/synapse/issues/8754))
- Refactor test utilities for injecting HTTP requests. ([\#8757](https://github.com/matrix-org/synapse/issues/8757), [\#8758](https://github.com/matrix-org/synapse/issues/8758), [\#8759](https://github.com/matrix-org/synapse/issues/8759), [\#8760](https://github.com/matrix-org/synapse/issues/8760), [\#8761](https://github.com/matrix-org/synapse/issues/8761), [\#8777](https://github.com/matrix-org/synapse/issues/8777))
- Consolidate logic between the OpenID Connect and SAML code. ([\#8765](https://github.com/matrix-org/synapse/issues/8765))
- Use `TYPE_CHECKING` instead of magic `MYPY` variable. ([\#8770](https://github.com/matrix-org/synapse/issues/8770))
- Add a commandline script to sign arbitrary json objects. ([\#8772](https://github.com/matrix-org/synapse/issues/8772))
- Minor log line improvements for the SSO mapping code used to generate Matrix IDs from SSO IDs. ([\#8773](https://github.com/matrix-org/synapse/issues/8773))
- Add additional error checking for OpenID Connect and SAML mapping providers. ([\#8774](https://github.com/matrix-org/synapse/issues/8774), [\#8800](https://github.com/matrix-org/synapse/issues/8800))
- Add type hints to HTTP abstractions. ([\#8806](https://github.com/matrix-org/synapse/issues/8806), [\#8812](https://github.com/matrix-org/synapse/issues/8812))
- Remove unnecessary function arguments and add typing to several membership replication classes. ([\#8809](https://github.com/matrix-org/synapse/issues/8809))
- Optimise the lookup for an invite from another homeserver when trying to reject it. ([\#8815](https://github.com/matrix-org/synapse/issues/8815))
- Add tests for `password_auth_provider`s. ([\#8819](https://github.com/matrix-org/synapse/issues/8819))
- Drop redundant database index on `event_json`. ([\#8845](https://github.com/matrix-org/synapse/issues/8845))
- Simplify `uk.half-shot.msc2778.login.application_service` login handler. ([\#8847](https://github.com/matrix-org/synapse/issues/8847))
- Refactor `password_auth_provider` support code. ([\#8849](https://github.com/matrix-org/synapse/issues/8849))
- Add missing `ordering` to background database updates. ([\#8850](https://github.com/matrix-org/synapse/issues/8850))
- Allow for specifying a room version when creating a room in unit tests via `RestHelper.create_room_as`. ([\#8854](https://github.com/matrix-org/synapse/issues/8854))
Synapse 1.23.0 (2020-11-18)
===========================
@ -6322,8 +6394,8 @@ Changes in synapse 0.5.1 (2014-11-26)
See UPGRADES.rst for specific instructions on how to upgrade.
> - Fix bug where we served up an Event that did not match its signatures.
> - Fix regression where we no longer correctly handled the case where a homeserver receives an event for a room it doesn\'t recognise (but is in.)
- Fix bug where we served up an Event that did not match its signatures.
- Fix regression where we no longer correctly handled the case where a homeserver receives an event for a room it doesn\'t recognise (but is in.)
Changes in synapse 0.5.0 (2014-11-19)
=====================================
@ -6334,7 +6406,7 @@ This release also changes the internal database schemas and so requires servers
Homeserver:
: - Add authentication and authorization to the federation protocol. Events are now signed by their originating homeservers.
- Add authentication and authorization to the federation protocol. Events are now signed by their originating homeservers.
- Implement the new authorization model for rooms.
- Split out web client into a seperate repository: matrix-angular-sdk.
- Change the structure of PDUs.
@ -6346,19 +6418,19 @@ Homeserver:
Webclient:
: - The webclient has been moved to a seperate repository.
- The webclient has been moved to a seperate repository.
Changes in synapse 0.4.2 (2014-10-31)
=====================================
Homeserver:
: - Fix bugs where we did not notify users of correct presence updates.
- Fix bugs where we did not notify users of correct presence updates.
- Fix bug where we did not handle sub second event stream timeouts.
Webclient:
: - Add ability to click on messages to see JSON.
- Add ability to click on messages to see JSON.
- Add ability to redact messages.
- Add ability to view and edit all room state JSON.
- Handle incoming redactions.
@ -6371,7 +6443,7 @@ Changes in synapse 0.4.1 (2014-10-17)
Webclient:
: - Fix bug with display of timestamps.
- Fix bug with display of timestamps.
Changes in synpase 0.4.0 (2014-10-17)
=====================================
@ -6384,7 +6456,7 @@ You will also need an updated syutil and config. See UPGRADES.rst.
Homeserver:
: - Sign federation transactions to assert strong identity over federation.
- Sign federation transactions to assert strong identity over federation.
- Rename timestamp keys in PDUs and events from \'ts\' and \'hsob\_ts\' to \'origin\_server\_ts\'.
Changes in synapse 0.3.4 (2014-09-25)
@ -6394,14 +6466,14 @@ This version adds support for using a TURN server. See docs/turn-howto.rst on ho
Homeserver:
: - Add support for redaction of messages.
- Add support for redaction of messages.
- Fix bug where inviting a user on a remote home server could take up to 20-30s.
- Implement a get current room state API.
- Add support specifying and retrieving turn server configuration.
Webclient:
: - Add button to send messages to users from the home page.
- Add button to send messages to users from the home page.
- Add support for using TURN for VoIP calls.
- Show display name change messages.
- Fix bug where the client didn\'t get the state of a newly joined room until after it has been refreshed.
@ -6416,11 +6488,11 @@ Changes in synapse 0.3.3 (2014-09-22)
Homeserver:
: - Fix bug where you continued to get events for rooms you had left.
- Fix bug where you continued to get events for rooms you had left.
Webclient:
: - Add support for video calls with basic UI.
- Add support for video calls with basic UI.
- Fix bug where one to one chats were named after your display name rather than the other person\'s.
- Fix bug which caused lag when typing in the textarea.
- Refuse to run on browsers we know won\'t work.
@ -6435,7 +6507,7 @@ Changes in synapse 0.3.2 (2014-09-18)
Webclient:
: - Fix bug where an empty \"bing words\" list in old accounts didn\'t send notifications when it should have done.
- Fix bug where an empty \"bing words\" list in old accounts didn\'t send notifications when it should have done.
Changes in synapse 0.3.1 (2014-09-18)
=====================================
@ -6444,7 +6516,7 @@ This is a release to hotfix v0.3.0 to fix two regressions.
Webclient:
: - Fix a regression where we sometimes displayed duplicate events.
- Fix a regression where we sometimes displayed duplicate events.
- Fix a regression where we didn\'t immediately remove rooms you were banned in from the recents list.
Changes in synapse 0.3.0 (2014-09-18)
@ -6454,7 +6526,7 @@ See UPGRADE for information about changes to the client server API, including br
Homeserver:
: - When a user changes their displayname or avatar the server will now update all their join states to reflect this.
- When a user changes their displayname or avatar the server will now update all their join states to reflect this.
- The server now adds \"age\" key to events to indicate how old they are. This is clock independent, so at no point does any server or webclient have to assume their clock is in sync with everyone else.
- Fix bug where we didn\'t correctly pull in missing PDUs.
- Fix bug where prev\_content key wasn\'t always returned.
@ -6462,7 +6534,7 @@ Homeserver:
Webclient:
: - Improve page content loading.
- Improve page content loading.
- Join/parts now trigger desktop notifications.
- Always show room aliases in the UI if one is present.
- No longer show user-count in the recents side panel.
@ -6475,7 +6547,7 @@ Webclient:
Registration API:
: - The registration API has been overhauled to function like the login API. In practice, this means registration requests must now include the following: \'type\':\'m.login.password\'. See UPGRADE for more information on this.
- The registration API has been overhauled to function like the login API. In practice, this means registration requests must now include the following: \'type\':\'m.login.password\'. See UPGRADE for more information on this.
- The \'user\_id\' key has been renamed to \'user\' to better match the login API.
- There is an additional login type: \'m.login.email.identity\'.
- The command client and web client have been updated to reflect these changes.
@ -6485,12 +6557,12 @@ Changes in synapse 0.2.3 (2014-09-12)
Homeserver:
: - Fix bug where we stopped sending events to remote home servers if a user from that home server left, even if there were some still in the room.
- Fix bug where we stopped sending events to remote home servers if a user from that home server left, even if there were some still in the room.
- Fix bugs in the state conflict resolution where it was incorrectly rejecting events.
Webclient:
: - Display room names and topics.
- Display room names and topics.
- Allow setting/editing of room names and topics.
- Display information about rooms on the main page.
- Handle ban and kick events in real time.
@ -6507,14 +6579,14 @@ Changes in synapse 0.2.2 (2014-09-06)
Homeserver:
: - When the server returns state events it now also includes the previous content.
- When the server returns state events it now also includes the previous content.
- Add support for inviting people when creating a new room.
- Make the homeserver inform the room via m.room.aliases when a new alias is added for a room.
- Validate m.room.power\_level events.
Webclient:
: - Add support for captchas on registration.
- Add support for captchas on registration.
- Handle m.room.aliases events.
- Asynchronously send messages and show a local echo.
- Inform the UI when a message failed to send.
@ -6526,7 +6598,7 @@ Changes in synapse 0.2.1 (2014-09-03)
Homeserver:
: - Added support for signing up with a third party id.
- Added support for signing up with a third party id.
- Add synctl scripts.
- Added rate limiting.
- Add option to change the external address the content repo uses.
@ -6534,7 +6606,7 @@ Homeserver:
Webclient:
: - Added support for signing up with a third party id.
- Added support for signing up with a third party id.
- Added support for banning and kicking users.
- Added support for displaying and setting ops.
- Added support for room names.
@ -6547,7 +6619,7 @@ This update changes many configuration options, updates the database schema and
Homeserver:
: - Require SSL for server-server connections.
- Require SSL for server-server connections.
- Add SSL listener for client-server connections.
- Add ability to use config files.
- Add support for kicking/banning and power levels.
@ -6558,7 +6630,7 @@ Homeserver:
Webclient:
: - Reskin the CSS for registration and login.
- Reskin the CSS for registration and login.
- Various improvements to rooms CSS.
- Support changes in client-server API.
- Bug fixes to VOIP UI.
@ -6569,14 +6641,14 @@ Changes in synapse 0.1.2 (2014-08-29)
Webclient:
: - Add basic call state UI for VoIP calls.
- Add basic call state UI for VoIP calls.
Changes in synapse 0.1.1 (2014-08-29)
=====================================
Homeserver:
: - Fix bug that caused the event stream to not notify some clients about changes.
- Fix bug that caused the event stream to not notify some clients about changes.
Changes in synapse 0.1.0 (2014-08-29)
=====================================
@ -6585,20 +6657,16 @@ Presence has been reenabled in this release.
Homeserver:
: -
Update client to server API, including:
: - Use a more consistent url scheme.
- Update client to server API, including:
- Use a more consistent url scheme.
- Provide more useful information in the initial sync api.
- Change the presence handling to be much more efficient.
- Change the presence server to server API to not require explicit polling of all users who share a room with a user.
- Fix races in the event streaming logic.
Webclient:
: - Update to use new client to server API.
- Update to use new client to server API.
- Add basic VOIP support.
- Add idle timers that change your status to away.
- Add recent rooms column when viewing a room.
@ -6613,7 +6681,7 @@ Presence has been disabled in this release due to a bug that caused the homeserv
Homeserver:
: - Completely change the database schema to support generic event types.
- Completely change the database schema to support generic event types.
- Improve presence reliability.
- Improve reliability of joining remote rooms.
- Fix bug where room join events were duplicated.
@ -6622,7 +6690,7 @@ Homeserver:
Webclient:
: - Add tab completion of names.
- Add tab completion of names.
- Add ability to upload and send images.
- Add profile pages.
- Improve CSS layout of room.
@ -6635,4 +6703,4 @@ Webclient:
Changes in synapse 0.0.0 (2014-08-13)
=====================================
> - Initial alpha release
- Initial alpha release

View file

@ -487,7 +487,7 @@ In nginx this would be something like:
```
location /.well-known/matrix/client {
return 200 '{"m.homeserver": {"base_url": "https://<matrix.example.com>"}}';
add_header Content-Type application/json;
default_type application/json;
add_header Access-Control-Allow-Origin *;
}
```

View file

@ -75,6 +75,58 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.24.0
====================
Custom OpenID Connect mapping provider breaking change
------------------------------------------------------
This release allows the OpenID Connect mapping provider to perform normalisation
of the localpart of the Matrix ID. This allows for the mapping provider to
specify different algorithms, instead of the [default way](https://matrix.org/docs/spec/appendices#mapping-from-other-character-sets).
If your Synapse configuration uses a custom mapping provider
(`oidc_config.user_mapping_provider.module` is specified and not equal to
`synapse.handlers.oidc_handler.JinjaOidcMappingProvider`) then you *must* ensure
that `map_user_attributes` of the mapping provider performs some normalisation
of the `localpart` returned. To match previous behaviour you can use the
`map_username_to_mxid_localpart` function provided by Synapse. An example is
shown below:
.. code-block:: python
from synapse.types import map_username_to_mxid_localpart
class MyMappingProvider:
def map_user_attributes(self, userinfo, token):
# ... your custom logic ...
sso_user_id = ...
localpart = map_username_to_mxid_localpart(sso_user_id)
return {"localpart": localpart}
Removal historical Synapse Admin API
------------------------------------
Historically, the Synapse Admin API has been accessible under:
* ``/_matrix/client/api/v1/admin``
* ``/_matrix/client/unstable/admin``
* ``/_matrix/client/r0/admin``
* ``/_synapse/admin/v1``
The endpoints with ``/_matrix/client/*`` prefixes have been removed as of v1.24.0.
The Admin API is now only accessible under:
* ``/_synapse/admin/v1``
The only exception is the `/admin/whois` endpoint, which is
`also available via the client-server API <https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid>`_.
The deprecation of the old endpoints was announced with Synapse 1.20.0 (released
on 2020-09-22) and makes it easier for homeserver admins to lock down external
access to the Admin API endpoints.
Upgrading to v1.23.0
====================

View file

@ -20,6 +20,7 @@ Add a new job to the main prometheus.conf file:
```
### for Prometheus v2
Add a new job to the main prometheus.yml file:
```yaml
@ -29,9 +30,12 @@ Add a new job to the main prometheus.yml file:
scheme: "https"
static_configs:
- targets: ['SERVER.LOCATION:PORT']
- targets: ["my.server.here:port"]
```
An example of a Prometheus configuration with workers can be found in
[metrics-howto.md](https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md).
To use `synapse.rules` add
```yaml

View file

@ -9,7 +9,7 @@
new PromConsole.Graph({
node: document.querySelector("#process_resource_utime"),
expr: "rate(process_cpu_seconds_total[2m]) * 100",
name: "[[job]]",
name: "[[job]]-[[index]]",
min: 0,
max: 100,
renderer: "line",
@ -22,12 +22,12 @@ new PromConsole.Graph({
</script>
<h3>Memory</h3>
<div id="process_resource_maxrss"></div>
<div id="process_resident_memory_bytes"></div>
<script>
new PromConsole.Graph({
node: document.querySelector("#process_resource_maxrss"),
expr: "process_psutil_rss:max",
name: "Maxrss",
node: document.querySelector("#process_resident_memory_bytes"),
expr: "process_resident_memory_bytes",
name: "[[job]]-[[index]]",
min: 0,
renderer: "line",
height: 150,
@ -43,8 +43,8 @@ new PromConsole.Graph({
<script>
new PromConsole.Graph({
node: document.querySelector("#process_fds"),
expr: "process_open_fds{job='synapse'}",
name: "FDs",
expr: "process_open_fds",
name: "[[job]]-[[index]]",
min: 0,
renderer: "line",
height: 150,
@ -62,8 +62,8 @@ new PromConsole.Graph({
<script>
new PromConsole.Graph({
node: document.querySelector("#reactor_total_time"),
expr: "rate(python_twisted_reactor_tick_time:total[2m]) / 1000",
name: "time",
expr: "rate(python_twisted_reactor_tick_time_sum[2m])",
name: "[[job]]-[[index]]",
max: 1,
min: 0,
renderer: "area",
@ -80,8 +80,8 @@ new PromConsole.Graph({
<script>
new PromConsole.Graph({
node: document.querySelector("#reactor_average_time"),
expr: "rate(python_twisted_reactor_tick_time:total[2m]) / rate(python_twisted_reactor_tick_time:count[2m]) / 1000",
name: "time",
expr: "rate(python_twisted_reactor_tick_time_sum[2m]) / rate(python_twisted_reactor_tick_time_count[2m])",
name: "[[job]]-[[index]]",
min: 0,
renderer: "line",
height: 150,
@ -97,14 +97,14 @@ new PromConsole.Graph({
<script>
new PromConsole.Graph({
node: document.querySelector("#reactor_pending_calls"),
expr: "rate(python_twisted_reactor_pending_calls:total[30s])/rate(python_twisted_reactor_pending_calls:count[30s])",
name: "calls",
expr: "rate(python_twisted_reactor_pending_calls_sum[30s]) / rate(python_twisted_reactor_pending_calls_count[30s])",
name: "[[job]]-[[index]]",
min: 0,
renderer: "line",
height: 150,
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
yTitle: "Pending Cals"
yTitle: "Pending Calls"
})
</script>
@ -115,7 +115,7 @@ new PromConsole.Graph({
<script>
new PromConsole.Graph({
node: document.querySelector("#synapse_storage_query_time"),
expr: "rate(synapse_storage_query_time:count[2m])",
expr: "sum(rate(synapse_storage_query_time_count[2m])) by (verb)",
name: "[[verb]]",
yAxisFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
yHoverFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
@ -129,8 +129,8 @@ new PromConsole.Graph({
<script>
new PromConsole.Graph({
node: document.querySelector("#synapse_storage_transaction_time"),
expr: "rate(synapse_storage_transaction_time:count[2m])",
name: "[[desc]]",
expr: "topk(10, rate(synapse_storage_transaction_time_count[2m]))",
name: "[[job]]-[[index]] [[desc]]",
min: 0,
yAxisFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
yHoverFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
@ -140,12 +140,12 @@ new PromConsole.Graph({
</script>
<h3>Transaction execution time</h3>
<div id="synapse_storage_transactions_time_msec"></div>
<div id="synapse_storage_transactions_time_sec"></div>
<script>
new PromConsole.Graph({
node: document.querySelector("#synapse_storage_transactions_time_msec"),
expr: "rate(synapse_storage_transaction_time:total[2m]) / 1000",
name: "[[desc]]",
node: document.querySelector("#synapse_storage_transactions_time_sec"),
expr: "rate(synapse_storage_transaction_time_sum[2m])",
name: "[[job]]-[[index]] [[desc]]",
min: 0,
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
@ -154,34 +154,33 @@ new PromConsole.Graph({
})
</script>
<h3>Database scheduling latency</h3>
<div id="synapse_storage_schedule_time"></div>
<h3>Average time waiting for database connection</h3>
<div id="synapse_storage_avg_waiting_time"></div>
<script>
new PromConsole.Graph({
node: document.querySelector("#synapse_storage_schedule_time"),
expr: "rate(synapse_storage_schedule_time:total[2m]) / 1000",
name: "Total latency",
node: document.querySelector("#synapse_storage_avg_waiting_time"),
expr: "rate(synapse_storage_schedule_time_sum[2m]) / rate(synapse_storage_schedule_time_count[2m])",
name: "[[job]]-[[index]]",
min: 0,
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
yUnits: "s/s",
yTitle: "Usage"
yUnits: "s",
yTitle: "Time"
})
</script>
<h3>Cache hit ratio</h3>
<div id="synapse_cache_ratio"></div>
<h3>Cache request rate</h3>
<div id="synapse_cache_request_rate"></div>
<script>
new PromConsole.Graph({
node: document.querySelector("#synapse_cache_ratio"),
expr: "rate(synapse_util_caches_cache:total[2m]) * 100",
name: "[[name]]",
node: document.querySelector("#synapse_cache_request_rate"),
expr: "rate(synapse_util_caches_cache:total[2m])",
name: "[[job]]-[[index]] [[name]]",
min: 0,
max: 100,
yAxisFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
yHoverFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
yUnits: "%",
yTitle: "Percentage"
yUnits: "rps",
yTitle: "Cache request rate"
})
</script>
@ -191,7 +190,7 @@ new PromConsole.Graph({
new PromConsole.Graph({
node: document.querySelector("#synapse_cache_size"),
expr: "synapse_util_caches_cache:size",
name: "[[name]]",
name: "[[job]]-[[index]] [[name]]",
yAxisFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
yHoverFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
yUnits: "",
@ -206,8 +205,8 @@ new PromConsole.Graph({
<script>
new PromConsole.Graph({
node: document.querySelector("#synapse_http_server_request_count_servlet"),
expr: "rate(synapse_http_server_request_count:servlet[2m])",
name: "[[servlet]]",
expr: "rate(synapse_http_server_in_flight_requests_count[2m])",
name: "[[job]]-[[index]] [[method]] [[servlet]]",
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
yUnits: "req/s",
@ -219,8 +218,8 @@ new PromConsole.Graph({
<script>
new PromConsole.Graph({
node: document.querySelector("#synapse_http_server_request_count_servlet_minus_events"),
expr: "rate(synapse_http_server_request_count:servlet{servlet!=\"EventStreamRestServlet\", servlet!=\"SyncRestServlet\"}[2m])",
name: "[[servlet]]",
expr: "rate(synapse_http_server_in_flight_requests_count{servlet!=\"EventStreamRestServlet\", servlet!=\"SyncRestServlet\"}[2m])",
name: "[[job]]-[[index]] [[method]] [[servlet]]",
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
yUnits: "req/s",
@ -233,8 +232,8 @@ new PromConsole.Graph({
<script>
new PromConsole.Graph({
node: document.querySelector("#synapse_http_server_response_time_avg"),
expr: "rate(synapse_http_server_response_time_seconds[2m]) / rate(synapse_http_server_response_count[2m]) / 1000",
name: "[[servlet]]",
expr: "rate(synapse_http_server_response_time_seconds_sum[2m]) / rate(synapse_http_server_response_count[2m])",
name: "[[job]]-[[index]] [[servlet]]",
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
yUnits: "s/req",
@ -277,7 +276,7 @@ new PromConsole.Graph({
new PromConsole.Graph({
node: document.querySelector("#synapse_http_server_response_ru_utime"),
expr: "rate(synapse_http_server_response_ru_utime_seconds[2m])",
name: "[[servlet]]",
name: "[[job]]-[[index]] [[servlet]]",
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
yUnits: "s/s",
@ -292,7 +291,7 @@ new PromConsole.Graph({
new PromConsole.Graph({
node: document.querySelector("#synapse_http_server_response_db_txn_duration"),
expr: "rate(synapse_http_server_response_db_txn_duration_seconds[2m])",
name: "[[servlet]]",
name: "[[job]]-[[index]] [[servlet]]",
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
yUnits: "s/s",
@ -306,8 +305,8 @@ new PromConsole.Graph({
<script>
new PromConsole.Graph({
node: document.querySelector("#synapse_http_server_send_time_avg"),
expr: "rate(synapse_http_server_response_time_second{servlet='RoomSendEventRestServlet'}[2m]) / rate(synapse_http_server_response_count{servlet='RoomSendEventRestServlet'}[2m]) / 1000",
name: "[[servlet]]",
expr: "rate(synapse_http_server_response_time_seconds_sum{servlet='RoomSendEventRestServlet'}[2m]) / rate(synapse_http_server_response_count{servlet='RoomSendEventRestServlet'}[2m])",
name: "[[job]]-[[index]] [[servlet]]",
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
yUnits: "s/req",
@ -323,7 +322,7 @@ new PromConsole.Graph({
new PromConsole.Graph({
node: document.querySelector("#synapse_federation_client_sent"),
expr: "rate(synapse_federation_client_sent[2m])",
name: "[[type]]",
name: "[[job]]-[[index]] [[type]]",
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
yUnits: "req/s",
@ -337,7 +336,7 @@ new PromConsole.Graph({
new PromConsole.Graph({
node: document.querySelector("#synapse_federation_server_received"),
expr: "rate(synapse_federation_server_received[2m])",
name: "[[type]]",
name: "[[job]]-[[index]] [[type]]",
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
yUnits: "req/s",
@ -367,7 +366,7 @@ new PromConsole.Graph({
new PromConsole.Graph({
node: document.querySelector("#synapse_notifier_listeners"),
expr: "synapse_notifier_listeners",
name: "listeners",
name: "[[job]]-[[index]]",
min: 0,
yAxisFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
yHoverFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
@ -382,7 +381,7 @@ new PromConsole.Graph({
new PromConsole.Graph({
node: document.querySelector("#synapse_notifier_notified_events"),
expr: "rate(synapse_notifier_notified_events[2m])",
name: "events",
name: "[[job]]-[[index]]",
yAxisFormatter: PromConsole.NumberFormatter.humanize,
yHoverFormatter: PromConsole.NumberFormatter.humanize,
yUnits: "events/s",

View file

@ -1,6 +1,7 @@
# List all media in a room
This API gets a list of known media in a room.
However, it only shows media from unencrypted events or rooms.
The API is:
```

View file

@ -382,7 +382,7 @@ the new room. Users on other servers will be unaffected.
The API is:
```json
```
POST /_synapse/admin/v1/rooms/<room_id>/delete
```
@ -439,6 +439,10 @@ The following JSON body parameters are available:
future attempts to join the room. Defaults to `false`.
* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database.
Defaults to `true`.
* `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it
will force a purge to go ahead even if there are local users still in the room. Do not
use this unless a regular `purge` operation fails, as it could leave those users'
clients in a confused state.
The JSON body must not be empty. The body must be at least `{}`.

View file

@ -176,6 +176,13 @@ The api is::
GET /_synapse/admin/v1/whois/<user_id>
and::
GET /_matrix/client/r0/admin/whois/<userId>
See also: `Client Server API Whois
<https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid>`_
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
@ -254,7 +261,7 @@ with a body of:
{
"new_password": "<secret>",
"logout_devices": true,
"logout_devices": true
}
To use it, you will need to authenticate by providing an ``access_token`` for a
@ -424,6 +431,41 @@ The following fields are returned in the JSON response body:
- ``next_token``: integer - Indication for pagination. See above.
- ``total`` - integer - Total number of media.
Login as a user
===============
Get an access token that can be used to authenticate as that user. Useful for
when admins wish to do actions on behalf of a user.
The API is::
POST /_synapse/admin/v1/users/<user_id>/login
{}
An optional ``valid_until_ms`` field can be specified in the request body as an
integer timestamp that specifies when the token should expire. By default tokens
do not expire.
A response body like the following is returned:
.. code:: json
{
"access_token": "<opaque_access_token_string>"
}
This API does *not* generate a new device for the user, and so will not appear
their ``/devices`` list, and in general the target user should not be able to
tell they have been logged in as.
To expire the token call the standard ``/logout`` API with the token.
Note: The token will expire if the *admin* user calls ``/logout/all`` from any
of their devices, but the token will *not* expire if the target user does the
same.
User devices
============

View file

@ -13,10 +13,12 @@
can be enabled by adding the \"metrics\" resource to the existing
listener as such:
```yaml
resources:
- names:
- client
- metrics
```
This provides a simple way of adding metrics to your Synapse
installation, and serves under `/_synapse/metrics`. If you do not
@ -31,11 +33,13 @@
Add a new listener to homeserver.yaml:
```yaml
listeners:
- type: metrics
port: 9000
bind_addresses:
- '0.0.0.0'
```
For both options, you will need to ensure that `enable_metrics` is
set to `True`.
@ -47,10 +51,13 @@
It needs to set the `metrics_path` to a non-default value (under
`scrape_configs`):
```yaml
- job_name: "synapse"
scrape_interval: 15s
metrics_path: "/_synapse/metrics"
static_configs:
- targets: ["my.server.here:port"]
```
where `my.server.here` is the IP address of Synapse, and `port` is
the listener port configured with the `metrics` resource.
@ -60,7 +67,8 @@
1. Restart Prometheus.
1. Consider using the [grafana dashboard](https://github.com/matrix-org/synapse/tree/master/contrib/grafana/) and required [recording rules](https://github.com/matrix-org/synapse/tree/master/contrib/prometheus/)
1. Consider using the [grafana dashboard](https://github.com/matrix-org/synapse/tree/master/contrib/grafana/)
and required [recording rules](https://github.com/matrix-org/synapse/tree/master/contrib/prometheus/)
## Monitoring workers
@ -87,6 +95,38 @@ don't clash with an existing worker.
With this example, the worker's metrics would then be available
on `http://127.0.0.1:9101`.
Example Prometheus target for Synapse with workers:
```yaml
- job_name: "synapse"
scrape_interval: 15s
metrics_path: "/_synapse/metrics"
static_configs:
- targets: ["my.server.here:port"]
labels:
instance: "my.server"
job: "master"
index: 1
- targets: ["my.workerserver.here:port"]
labels:
instance: "my.server"
job: "generic_worker"
index: 1
- targets: ["my.workerserver.here:port"]
labels:
instance: "my.server"
job: "generic_worker"
index: 2
- targets: ["my.workerserver.here:port"]
labels:
instance: "my.server"
job: "media_repository"
index: 1
```
Labels (`instance`, `job`, `index`) can be defined as anything.
The labels are used to group graphs in grafana.
## Renaming of metrics & deprecation of old names in 1.2
Synapse 1.2 updates the Prometheus metrics to match the naming

View file

@ -26,6 +26,7 @@ Password auth provider classes must provide the following methods:
It should perform any appropriate sanity checks on the provided
configuration, and return an object which is then passed into
`__init__`.
This method should have the `@staticmethod` decoration.

View file

@ -1230,8 +1230,9 @@ account_validity:
# email will be globally disabled.
#
# Additionally, if `msisdn` is not set, registration and password resets via msisdn
# will be disabled regardless. This is due to Synapse currently not supporting any
# method of sending SMS messages on its own.
# will be disabled regardless, and users will not be able to associate an msisdn
# identifier to their account. This is due to Synapse currently not supporting
# any method of sending SMS messages on its own.
#
# To enable using an identity server for operations regarding a particular third-party
# identifier type, set the value to the URL of that identity server as shown in the
@ -1545,6 +1546,12 @@ saml2_config:
# remote:
# - url: https://our_idp/metadata.xml
# Allowed clock difference in seconds between the homeserver and IdP.
#
# Uncomment the below to increase the accepted time difference from 0 to 3 seconds.
#
#accepted_time_diff: 3
# By default, the user has to go to our login page first. If you'd like
# to allow IdP-initiated login, set 'allow_unsolicited: true' in a
# 'service.sp' section:
@ -1667,6 +1674,14 @@ saml2_config:
# - attribute: department
# value: "sales"
# If the metadata XML contains multiple IdP entities then the `idp_entityid`
# option must be set to the entity to redirect users to.
#
# Most deployments only have a single IdP entity and so should omit this
# option.
#
#idp_entityid: 'https://our_idp/entityid'
# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
#
@ -2236,6 +2251,9 @@ password_providers:
## Push ##
push:
# Clients requesting push notifications can either have the body of
# the message sent in the notification poke along with other details
# like the sender, or just the event ID and room ID (`event_id_only`).
@ -2248,8 +2266,20 @@ password_providers:
# because it is loaded by the app. iPhone, however will send a
# notification saying only that a message arrived and who it came from.
#
#push:
# include_content: true
# The default value is "true" to include message details. Uncomment to only
# include the event ID and room ID in push notification payloads.
#
#include_content: false
# When a push notification is received, an unread count is also sent.
# This number can either be calculated as the number of unread messages
# for the user, or the number of *rooms* the user has unread messages in.
#
# The default value is "true", meaning push clients will see the number of
# rooms with unread messages in them. Uncomment to instead send the number
# of unread messages.
#
#group_unread_count_by_room: false
# Spam checkers are third-party modules that can block specific actions

View file

@ -15,8 +15,15 @@ where SAML mapping providers come into play.
SSO mapping providers are currently supported for OpenID and SAML SSO
configurations. Please see the details below for how to implement your own.
It is the responsibility of the mapping provider to normalise the SSO attributes
and map them to a valid Matrix ID. The
[specification for Matrix IDs](https://matrix.org/docs/spec/appendices#user-identifiers)
has some information about what is considered valid. Alternately an easy way to
ensure it is valid is to use a Synapse utility function:
`synapse.types.map_username_to_mxid_localpart`.
External mapping providers are provided to Synapse in the form of an external
Python module. You can retrieve this module from [PyPi](https://pypi.org) or elsewhere,
Python module. You can retrieve this module from [PyPI](https://pypi.org) or elsewhere,
but it must be importable via Synapse (e.g. it must be in the same virtualenv
as Synapse). The Synapse config is then modified to point to the mapping provider
(and optionally provide additional configuration for it).
@ -56,13 +63,22 @@ A custom mapping provider must specify the following methods:
information from.
- This method must return a string, which is the unique identifier for the
user. Commonly the ``sub`` claim of the response.
* `map_user_attributes(self, userinfo, token)`
* `map_user_attributes(self, userinfo, token, failures)`
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
- `token` - A dictionary which includes information necessary to make
further requests to the OpenID provider.
- `failures` - An `int` that represents the amount of times the returned
mxid localpart mapping has failed. This should be used
to create a deduplicated mxid localpart which should be
returned instead. For example, if this method returns
`john.doe` as the value of `localpart` in the returned
dict, and that is already taken on the homeserver, this
method will be called again with the same parameters but
with failures=1. The method should then return a different
`localpart` value, such as `john.doe1`.
- Returns a dictionary with two keys:
- localpart: A required string, used to generate the Matrix ID.
- displayname: An optional string, the display name for the user.

View file

@ -42,10 +42,10 @@ This will install and start a systemd service called `coturn`.
./configure
> You may need to install `libevent2`: if so, you should do so in
> the way recommended by your operating system. You can ignore
> warnings about lack of database support: a database is unnecessary
> for this purpose.
You may need to install `libevent2`: if so, you should do so in
the way recommended by your operating system. You can ignore
warnings about lack of database support: a database is unnecessary
for this purpose.
1. Build and install it:
@ -66,6 +66,19 @@ This will install and start a systemd service called `coturn`.
pwgen -s 64 1
A `realm` must be specified, but its value is somewhat arbitrary. (It is
sent to clients as part of the authentication flow.) It is conventional to
set it to be your server name.
1. You will most likely want to configure coturn to write logs somewhere. The
easiest way is normally to send them to the syslog:
syslog
(in which case, the logs will be available via `journalctl -u coturn` on a
systemd system). Alternatively, coturn can be configured to write to a
logfile - check the example config file supplied with coturn.
1. Consider your security settings. TURN lets users request a relay which will
connect to arbitrary IP addresses and ports. The following configuration is
suggested as a minimum starting point:
@ -96,11 +109,31 @@ This will install and start a systemd service called `coturn`.
# TLS private key file
pkey=/path/to/privkey.pem
In this case, replace the `turn:` schemes in the `turn_uri` settings below
with `turns:`.
We recommend that you only try to set up TLS/DTLS once you have set up a
basic installation and got it working.
1. Ensure your firewall allows traffic into the TURN server on the ports
you've configured it to listen on (By default: 3478 and 5349 for the TURN(s)
you've configured it to listen on (By default: 3478 and 5349 for TURN
traffic (remember to allow both TCP and UDP traffic), and ports 49152-65535
for the UDP relay.)
1. We do not recommend running a TURN server behind NAT, and are not aware of
anyone doing so successfully.
If you want to try it anyway, you will at least need to tell coturn its
external IP address:
external-ip=192.88.99.1
... and your NAT gateway must forward all of the relayed ports directly
(eg, port 56789 on the external IP must be always be forwarded to port
56789 on the internal IP).
If you get this working, let us know!
1. (Re)start the turn server:
* If you used the Debian package (or have set up a systemd unit yourself):
@ -137,9 +170,10 @@ Your home server configuration file needs the following extra keys:
without having gone through a CAPTCHA or similar to register a
real account.
As an example, here is the relevant section of the config file for matrix.org:
As an example, here is the relevant section of the config file for `matrix.org`. The
`turn_uris` are appropriate for TURN servers listening on the default ports, with no TLS.
turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ]
turn_uris: [ "turn:turn.matrix.org?transport=udp", "turn:turn.matrix.org?transport=tcp" ]
turn_shared_secret: "n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons"
turn_user_lifetime: 86400000
turn_allow_guests: True
@ -155,5 +189,86 @@ After updating the homeserver configuration, you must restart synapse:
```
systemctl restart synapse.service
```
... and then reload any clients (or wait an hour for them to refresh their
settings).
..and your Home Server now supports VoIP relaying!
## Troubleshooting
The normal symptoms of a misconfigured TURN server are that calls between
devices on different networks ring, but get stuck at "call
connecting". Unfortunately, troubleshooting this can be tricky.
Here are a few things to try:
* Check that your TURN server is not behind NAT. As above, we're not aware of
anyone who has successfully set this up.
* Check that you have opened your firewall to allow TCP and UDP traffic to the
TURN ports (normally 3478 and 5479).
* Check that you have opened your firewall to allow UDP traffic to the UDP
relay ports (49152-65535 by default).
* Some WebRTC implementations (notably, that of Google Chrome) appear to get
confused by TURN servers which are reachable over IPv6 (this appears to be
an unexpected side-effect of its handling of multiple IP addresses as
defined by
[`draft-ietf-rtcweb-ip-handling`](https://tools.ietf.org/html/draft-ietf-rtcweb-ip-handling-12)).
Try removing any AAAA records for your TURN server, so that it is only
reachable over IPv4.
* Enable more verbose logging in coturn via the `verbose` setting:
```
verbose
```
... and then see if there are any clues in its logs.
* If you are using a browser-based client under Chrome, check
`chrome://webrtc-internals/` for insights into the internals of the
negotiation. On Firefox, check the "Connection Log" on `about:webrtc`.
(Understanding the output is beyond the scope of this document!)
* There is a WebRTC test tool at
https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To
use it, you will need a username/password for your TURN server. You can
either:
* look for the `GET /_matrix/client/r0/voip/turnServer` request made by a
matrix client to your homeserver in your browser's network inspector. In
the response you should see `username` and `password`. Or:
* Use the following shell commands:
```sh
secret=staticAuthSecretHere
u=$((`date +%s` + 3600)):test
p=$(echo -n $u | openssl dgst -hmac $secret -sha1 -binary | base64)
echo -e "username: $u\npassword: $p"
```
Or:
* Temporarily configure coturn to accept a static username/password. To do
this, comment out `use-auth-secret` and `static-auth-secret` and add the
following:
```
lt-cred-mech
user=username:password
```
**Note**: these settings will not take effect unless `use-auth-secret`
and `static-auth-secret` are disabled.
Restart coturn after changing the configuration file.
Remember to restore the original settings to go back to testing with
Matrix clients!
If the TURN server is working correctly, you should see at least one `relay`
entry in the results.

View file

@ -8,6 +8,7 @@ show_traceback = True
mypy_path = stubs
warn_unreachable = True
files =
scripts-dev/sign_json,
synapse/api,
synapse/appservice,
synapse/config,
@ -37,13 +38,17 @@ files =
synapse/handlers/presence.py,
synapse/handlers/profile.py,
synapse/handlers/read_marker.py,
synapse/handlers/register.py,
synapse/handlers/room.py,
synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py,
synapse/handlers/saml_handler.py,
synapse/handlers/sync.py,
synapse/handlers/ui_auth,
synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/well_known_resolver.py,
synapse/http/matrixfederationclient.py,
synapse/http/server.py,
synapse/http/site.py,
synapse/logging,
@ -75,6 +80,7 @@ files =
synapse/util/metrics.py,
tests/replication,
tests/test_utils,
tests/handlers/test_password_providers.py,
tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_stream_change_cache.py
@ -105,7 +111,7 @@ ignore_missing_imports = True
[mypy-opentracing]
ignore_missing_imports = True
[mypy-OpenSSL]
[mypy-OpenSSL.*]
ignore_missing_imports = True
[mypy-netaddr]

127
scripts-dev/sign_json Executable file
View file

@ -0,0 +1,127 @@
#!/usr/bin/env python
#
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import sys
from json import JSONDecodeError
import yaml
from signedjson.key import read_signing_keys
from signedjson.sign import sign_json
from synapse.util import json_encoder
def main():
parser = argparse.ArgumentParser(
description="""Adds a signature to a JSON object.
Example usage:
$ scripts-dev/sign_json.py -N test -k localhost.signing.key "{}"
{"signatures":{"test":{"ed25519:a_ZnZh":"LmPnml6iM0iR..."}}}
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"-N",
"--server-name",
help="Name to give as the local homeserver. If unspecified, will be "
"read from the config file.",
)
parser.add_argument(
"-k",
"--signing-key-path",
help="Path to the file containing the private ed25519 key to sign the "
"request with.",
)
parser.add_argument(
"-c",
"--config",
default="homeserver.yaml",
help=(
"Path to synapse config file, from which the server name and/or signing "
"key path will be read. Ignored if --server-name and --signing-key-path "
"are both given."
),
)
input_args = parser.add_mutually_exclusive_group()
input_args.add_argument("input_data", nargs="?", help="Raw JSON to be signed.")
input_args.add_argument(
"-i",
"--input",
type=argparse.FileType("r"),
default=sys.stdin,
help=(
"A file from which to read the JSON to be signed. If neither --input nor "
"input_data are given, JSON will be read from stdin."
),
)
parser.add_argument(
"-o",
"--output",
type=argparse.FileType("w"),
default=sys.stdout,
help="Where to write the signed JSON. Defaults to stdout.",
)
args = parser.parse_args()
if not args.server_name or not args.signing_key_path:
read_args_from_config(args)
with open(args.signing_key_path) as f:
key = read_signing_keys(f)[0]
json_to_sign = args.input_data
if json_to_sign is None:
json_to_sign = args.input.read()
try:
obj = json.loads(json_to_sign)
except JSONDecodeError as e:
print("Unable to parse input as JSON: %s" % e, file=sys.stderr)
sys.exit(1)
if not isinstance(obj, dict):
print("Input json was not an object", file=sys.stderr)
sys.exit(1)
sign_json(obj, args.server_name, key)
for c in json_encoder.iterencode(obj):
args.output.write(c)
args.output.write("\n")
def read_args_from_config(args: argparse.Namespace) -> None:
with open(args.config, "r") as fh:
config = yaml.safe_load(fh)
if not args.server_name:
args.server_name = config["server_name"]
if not args.signing_key_path:
args.signing_key_path = config["signing_key_path"]
if __name__ == "__main__":
main()

View file

@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.23.0"
__version__ = "1.24.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -37,7 +37,7 @@ def request_registration(
exit=sys.exit,
):
url = "%s/_matrix/client/r0/admin/register" % (server_location,)
url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
# Get the nonce
r = requests.get(url, verify=False)

View file

@ -14,10 +14,12 @@
# limitations under the License.
import logging
from typing import Optional
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
from synapse.types import Requester
logger = logging.getLogger(__name__)
@ -33,24 +35,47 @@ class AuthBlocking:
self._max_mau_value = hs.config.max_mau_value
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
self._server_name = hs.hostname
async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
async def check_auth_blocking(
self,
user_id: Optional[str] = None,
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
user_id(str|None): If present, checks for presence against existing
user_id: If present, checks for presence against existing
MAU cohort
threepid(dict|None): If present, checks for presence against configured
threepid: If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
user_type(str|None): If present, is used to decide whether to check against
user_type: If present, is used to decide whether to check against
certain blocking reasons like MAU.
requester: If present, and the authenticated entity is a user, checks for
presence against existing MAU cohort. Passing in both a `user_id` and
`requester` is an error.
"""
if requester and user_id:
raise Exception(
"Passed in both 'user_id' and 'requester' to 'check_auth_blocking'"
)
if requester:
if requester.authenticated_entity.startswith("@"):
user_id = requester.authenticated_entity
elif requester.authenticated_entity == self._server_name:
# We never block the server from doing actions on behalf of
# users.
return
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking

View file

@ -32,6 +32,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.server import ListenerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.util.async_helpers import Linearizer
from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
@ -244,6 +245,7 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
@wrap_as_background_process("sighup")
def handle_sighup(*args, **kwargs):
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
@ -254,7 +256,13 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
sdnotify(b"READY=1")
signal.signal(signal.SIGHUP, handle_sighup)
# We defer running the sighup handlers until next reactor tick. This
# is so that we're in a sane state, e.g. flushing the logs may fail
# if the sighup happens in the middle of writing a log entry.
def run_sighup(*args, **kwargs):
hs.get_clock().call_later(0, handle_sighup, *args, **kwargs)
signal.signal(signal.SIGHUP, run_sighup)
register_sighup(refresh_certificate, hs)
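The same defer-to-next-tick pattern, as a self-contained sketch against Twisted's reactor (which is what hs.get_clock().call_later(0, ...) wraps); the handler body is illustrative.

import signal

from twisted.internet import reactor

def handle_sighup(*args, **kwargs):
    # The real work runs here, from a clean reactor tick.
    print("flushing logs and reloading certificates")

def run_sighup(*args, **kwargs):
    # The OS may deliver SIGHUP mid-way through writing a log entry, so
    # don't do the work inside the interrupted frame; schedule it instead.
    reactor.callLater(0, handle_sighup, *args, **kwargs)

if hasattr(signal, "SIGHUP"):
    signal.signal(signal.SIGHUP, run_sighup)
# reactor.run() would be started elsewhere, as in Synapse's startup code.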

View file

@ -21,8 +21,11 @@ class PushConfig(Config):
section = "push"
def read_config(self, config, **kwargs):
push_config = config.get("push", {})
push_config = config.get("push") or {}
self.push_include_content = push_config.get("include_content", True)
self.push_group_unread_count_by_room = push_config.get(
"group_unread_count_by_room", True
)
pusher_instances = config.get("pusher_instances") or []
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
@ -49,6 +52,9 @@ class PushConfig(Config):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
## Push ##
push:
# Clients requesting push notifications can either have the body of
# the message sent in the notification poke along with other details
# like the sender, or just the event ID and room ID (`event_id_only`).
@ -61,6 +67,18 @@ class PushConfig(Config):
# because it is loaded by the app. iPhone, however, will send a
# notification saying only that a message arrived and who it came from.
#
#push:
# include_content: true
# The default value is "true" to include message details. Uncomment to only
# include the event ID and room ID in push notification payloads.
#
#include_content: false
# When a push notification is received, an unread count is also sent.
# This number can either be calculated as the number of unread messages
# for the user, or the number of *rooms* the user has unread messages in.
#
# The default value is "true", meaning push clients will see the number of
# rooms with unread messages in them. Uncomment to instead send the number
# of unread messages.
#
#group_unread_count_by_room: false
"""

View file

@ -347,8 +347,9 @@ class RegistrationConfig(Config):
# email will be globally disabled.
#
# Additionally, if `msisdn` is not set, registration and password resets via msisdn
# will be disabled regardless. This is due to Synapse currently not supporting any
# method of sending SMS messages on its own.
# will be disabled regardless, and users will not be able to associate an msisdn
# identifier to their account. This is due to Synapse currently not supporting
# any method of sending SMS messages on its own.
#
# To enable using an identity server for operations regarding a particular third-party
# identifier type, set the value to the URL of that identity server as shown in the

View file

@ -90,6 +90,8 @@ class SAML2Config(Config):
"grandfathered_mxid_source_attribute", "uid"
)
self.saml2_idp_entityid = saml2_config.get("idp_entityid", None)
# user_mapping_provider may be None if the key is present but has no value
ump_dict = saml2_config.get("user_mapping_provider") or {}
@ -256,6 +258,12 @@ class SAML2Config(Config):
# remote:
# - url: https://our_idp/metadata.xml
# Allowed clock difference in seconds between the homeserver and IdP.
#
# Uncomment the below to increase the accepted time difference from 0 to 3 seconds.
#
#accepted_time_diff: 3
# By default, the user has to go to our login page first. If you'd like
# to allow IdP-initiated login, set 'allow_unsolicited: true' in a
# 'service.sp' section:
@ -377,6 +385,14 @@ class SAML2Config(Config):
# value: "staff"
# - attribute: department
# value: "sales"
# If the metadata XML contains multiple IdP entities then the `idp_entityid`
# option must be set to the entity to redirect users to.
#
# Most deployments only have a single IdP entity and so should omit this
# option.
#
#idp_entityid: 'https://our_idp/entityid'
""" % {
"config_dir_path": config_dir_path
}

View file

@ -15,13 +15,12 @@
# limitations under the License.
import inspect
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
MYPY = False
if MYPY:
if TYPE_CHECKING:
import synapse.events
import synapse.server

View file

@ -49,6 +49,7 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
@ -391,7 +392,7 @@ class FederationServer(FederationBase):
TRANSACTION_CONCURRENCY_LIMIT,
)
async def on_context_state_request(
async def on_room_state_request(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
origin_host, _ = parse_server_name(origin)
@ -514,11 +515,12 @@ class FederationServer(FederationBase):
return {"event": ret_pdu.get_pdu_json(time_now)}
async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str
self, origin: str, content: JsonDict
) -> Dict[str, Any]:
logger.debug("on_send_join_request: content: %s", content)
room_version = await self.store.get_room_version(room_id)
assert_params_in_dict(content, ["room_id"])
room_version = await self.store.get_room_version(content["room_id"])
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
@ -547,12 +549,11 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
async def on_send_leave_request(
self, origin: str, content: JsonDict, room_id: str
) -> dict:
async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict:
logger.debug("on_send_leave_request: content: %s", content)
room_version = await self.store.get_room_version(room_id)
assert_params_in_dict(content, ["room_id"])
room_version = await self.store.get_room_version(content["room_id"])
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
@ -748,12 +749,8 @@ class FederationServer(FederationBase):
)
return ret
async def on_exchange_third_party_invite_request(
self, room_id: str, event_dict: Dict
):
ret = await self.handler.on_exchange_third_party_invite_request(
room_id, event_dict
)
async def on_exchange_third_party_invite_request(self, event_dict: Dict):
ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
return ret
async def check_server_matches_acl(self, server_name: str, room_id: str):
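These handlers now read the room ID from the request body rather than trusting a second copy in the URI, guarded by assert_params_in_dict; a minimal standalone equivalent of that guard (the error class here is a stub):

from typing import Any, Dict, Iterable

class SynapseError(Exception):  # stub of synapse.api.errors.SynapseError
    def __init__(self, code: int, msg: str, errcode: str = "M_UNKNOWN"):
        super().__init__(msg)
        self.code = code
        self.errcode = errcode

def assert_params_in_dict(body: Dict[str, Any], required: Iterable[str]) -> None:
    # Reject the request with a 400 / M_MISSING_PARAM if any required
    # key is absent from the JSON body.
    absent = [key for key in required if key not in body]
    if absent:
        raise SynapseError(400, "Missing params: %r" % (absent,), "M_MISSING_PARAM")

content = {"room_id": "!room:example.com", "type": "m.room.member"}
assert_params_in_dict(content, ["room_id"])   # passes
# assert_params_in_dict({}, ["room_id"])      # would raise the 400 above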

View file

@ -440,13 +440,13 @@ class FederationEventServlet(BaseFederationServlet):
class FederationStateV1Servlet(BaseFederationServlet):
PATH = "/state/(?P<context>[^/]*)/?"
PATH = "/state/(?P<room_id>[^/]*)/?"
# This is when someone asks for all data for a given context.
async def on_GET(self, origin, content, query, context):
return await self.handler.on_context_state_request(
# This is when someone asks for all data for a given room.
async def on_GET(self, origin, content, query, room_id):
return await self.handler.on_room_state_request(
origin,
context,
room_id,
parse_string_from_args(query, "event_id", None, required=False),
)
@ -463,16 +463,16 @@ class FederationStateIdsServlet(BaseFederationServlet):
class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/?"
PATH = "/backfill/(?P<room_id>[^/]*)/?"
async def on_GET(self, origin, content, query, context):
async def on_GET(self, origin, content, query, room_id):
versions = [x.decode("ascii") for x in query[b"v"]]
limit = parse_integer_from_args(query, "limit", None)
if not limit:
return 400, {"error": "Did not include limit param"}
return await self.handler.on_backfill_request(origin, context, versions, limit)
return await self.handler.on_backfill_request(origin, room_id, versions, limit)
class FederationQueryServlet(BaseFederationServlet):
@ -487,9 +487,9 @@ class FederationQueryServlet(BaseFederationServlet):
class FederationMakeJoinServlet(BaseFederationServlet):
PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, _content, query, context, user_id):
async def on_GET(self, origin, _content, query, room_id, user_id):
"""
Args:
origin (unicode): The authenticated server_name of the calling server
@ -511,16 +511,16 @@ class FederationMakeJoinServlet(BaseFederationServlet):
supported_versions = ["1"]
content = await self.handler.on_make_join_request(
origin, context, user_id, supported_versions=supported_versions
origin, room_id, user_id, supported_versions=supported_versions
)
return 200, content
class FederationMakeLeaveServlet(BaseFederationServlet):
PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, content, query, context, user_id):
content = await self.handler.on_make_leave_request(origin, context, user_id)
async def on_GET(self, origin, content, query, room_id, user_id):
content = await self.handler.on_make_leave_request(origin, room_id, user_id)
return 200, content
@ -528,7 +528,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
content = await self.handler.on_send_leave_request(origin, content)
return 200, (200, content)
@ -538,43 +538,43 @@ class FederationV2SendLeaveServlet(BaseFederationServlet):
PREFIX = FEDERATION_V2_PREFIX
async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
content = await self.handler.on_send_leave_request(origin, content)
return 200, content
class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_GET(self, origin, content, query, context, event_id):
return await self.handler.on_event_auth(origin, context, event_id)
async def on_GET(self, origin, content, query, room_id, event_id):
return await self.handler.on_event_auth(origin, room_id, event_id)
class FederationV1SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
async def on_PUT(self, origin, content, query, room_id, event_id):
# TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content
content = await self.handler.on_send_join_request(origin, content, context)
content = await self.handler.on_send_join_request(origin, content)
return 200, (200, content)
class FederationV2SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
async def on_PUT(self, origin, content, query, room_id, event_id):
# TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content
content = await self.handler.on_send_join_request(origin, content, context)
content = await self.handler.on_send_join_request(origin, content)
return 200, content
class FederationV1InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, context, event_id):
async def on_PUT(self, origin, content, query, room_id, event_id):
# We don't get a room version, so we have to assume it's EITHER v1 or
# v2. This is "fine" as the only difference between V1 and V2 is the
# state resolution algorithm, and we don't use that for processing
@ -589,12 +589,12 @@ class FederationV1InviteServlet(BaseFederationServlet):
class FederationV2InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
async def on_PUT(self, origin, content, query, room_id, event_id):
# TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content
room_version = content["room_version"]
@ -616,9 +616,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id):
content = await self.handler.on_exchange_third_party_invite_request(
room_id, content
)
content = await self.handler.on_exchange_third_party_invite_request(content)
return 200, content

View file

@ -169,7 +169,9 @@ class BaseHandler:
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
requester = synapse.types.create_requester(target_user, is_guest=True)
requester = synapse.types.create_requester(
target_user, is_guest=True, authenticated_entity=self.server_name
)
handler = self.hs.get_room_member_handler()
await handler.update_membership(
requester,

View file

@ -226,7 +226,7 @@ class ApplicationServicesHandler:
new_token: Optional[int],
users: Collection[Union[str, UserID]],
):
logger.info("Checking interested services for %s" % (stream_key))
logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
# Only handle typing if we have the latest token

View file

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -25,6 +26,7 @@ from typing import (
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
# better way to break the loop
account_handler = ModuleApi(hs, self)
self.password_providers = []
for module, config in hs.config.password_providers:
try:
self.password_providers.append(
module(config=config, account_handler=account_handler)
)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
self.password_providers = [
PasswordProvider.load(module, config, account_handler)
for module, config in hs.config.password_providers
]
logger.info("Extra password_providers: %r", self.password_providers)
logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
@ -205,15 +202,23 @@ class AuthHandler(BaseHandler):
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
if self._password_enabled:
if hs.config.password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
for t in provider.get_supported_login_types().keys():
if t not in login_types:
login_types.append(t)
if not self._password_enabled:
login_types.remove(LoginType.PASSWORD)
self._supported_login_types = login_types
# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
added in the rest layer, see synapse.rest.client.v1.login.LoginRestServlet.on_GET.
@ -230,6 +235,13 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
# Ratelimiter for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
@ -642,14 +654,8 @@ class AuthHandler(BaseHandler):
res = await checker.check_auth(authdict, clientip=clientip)
return res
# build a v1-login-style dict out of the authdict and fall back to the
# v1 code
user_id = authdict.get("user")
if user_id is None:
raise SynapseError(400, "", Codes.MISSING_PARAM)
(canonical_id, callback) = await self.validate_login(user_id, authdict)
# fall back to the v1 login flow
canonical_id, _ = await self.validate_login(authdict)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@ -698,8 +704,12 @@ class AuthHandler(BaseHandler):
}
async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
self,
user_id: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
) -> str:
"""
Creates a new access token for the user with the given user ID.
@ -725,13 +735,25 @@ class AuthHandler(BaseHandler):
fmt_expiry = time.strftime(
" until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
)
logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
if puppets_user_id:
logger.info(
"Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
)
else:
logger.info(
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
)
await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user(
user_id, access_token, device_id, valid_until_ms
user_id=user_id,
token=access_token,
device_id=device_id,
valid_until_ms=valid_until_ms,
puppets_user_id=puppets_user_id,
)
# the device *should* have been registered before we got here; however,
@ -808,15 +830,157 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
async def validate_login(
self, username: str, login_submission: Dict[str, Any]
self, login_submission: Dict[str, Any], ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
m.login.password auth types.
Also used by the user-interactive auth flow to validate auth types which don't
have an explicit UIA handler, including m.login.password.
Args:
username: username supplied by the user
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
ratelimit: whether to apply the failed_login_attempt ratelimiter
Returns:
A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
Raises:
StoreError if there was a problem accessing the database
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
login_type = login_submission.get("type")
if not isinstance(login_type, str):
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
# ideally, we wouldn't be checking the identifier unless we know we have a login
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
#
# But the auth providers' check_auth interface requires a username, so in
# practice we can only support login methods which we can map to a username
# anyway.
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
if not isinstance(password, str):
raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
# map old-school login fields into new-school "identifier" fields.
identifier_dict = convert_client_dict_legacy_fields_to_identifier(
login_submission
)
# convert phone type identifiers to generic threepids
if identifier_dict["type"] == "m.id.phone":
identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
# convert threepid identifiers to user IDs
if identifier_dict["type"] == "m.id.thirdparty":
address = identifier_dict.get("address")
medium = identifier_dict.get("medium")
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
# For emails, canonicalise the address.
# We store all email addresses canonicalised in the DB.
# (See add_threepid in synapse/handlers/auth.py)
if medium == "email":
try:
address = canonicalise_email(address)
except ValueError as e:
raise SynapseError(400, str(e))
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
if ratelimit:
self._failed_login_attempts_ratelimiter.ratelimit(
(medium, address), update=False
)
# Check for login providers that support 3pid login types
if login_type == LoginType.PASSWORD:
# we've already checked that there is a (valid) password field
assert isinstance(password, str)
(
canonical_user_id,
callback_3pid,
) = await self.check_password_provider_3pid(medium, address, password)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
return canonical_user_id, callback_3pid
# No password providers were able to handle this 3pid
# Check local store
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address
)
if not user_id:
logger.warning(
"unknown 3pid identifier medium %s, address %r", medium, address
)
# We mark that we've failed to log in here, as
# `check_password_provider_3pid` might have returned `None` due
# to an incorrect password, rather than the account not
# existing.
#
# If it returned None but the 3PID was bound then we won't hit
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
if ratelimit:
self._failed_login_attempts_ratelimiter.can_do_action(
(medium, address)
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier_dict = {"type": "m.id.user", "user": user_id}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
if identifier_dict["type"] != "m.id.user":
raise SynapseError(400, "Unknown login identifier type")
username = identifier_dict.get("user")
if not username:
raise SynapseError(400, "User identifier is missing 'user' key")
if username.startswith("@"):
qualified_user_id = username
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
# Check if we've hit the failed ratelimit (but don't update it)
if ratelimit:
self._failed_login_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(), update=False
)
try:
return await self._validate_userid_login(username, login_submission)
except LoginError:
# The user has failed to log in, so we need to update the rate
# limiter. Using `can_do_action` avoids us raising a ratelimit
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
if ratelimit:
self._failed_login_attempts_ratelimiter.can_do_action(
qualified_user_id.lower()
)
raise
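For reference, the kinds of /login bodies that now all flow through validate_login: legacy fields are converted to an identifier dict, m.id.phone is converted to m.id.thirdparty, and 3pids are resolved to user IDs before _validate_userid_login runs (credentials below are illustrative).

password_login = {
    "type": "m.login.password",
    "identifier": {"type": "m.id.user", "user": "alice"},
    "password": "correct horse battery staple",
}

threepid_login = {
    "type": "m.login.password",
    "identifier": {
        "type": "m.id.thirdparty",
        "medium": "email",
        "address": "Alice@Example.com",  # canonicalised before lookup
    },
    "password": "correct horse battery staple",
}

legacy_login = {  # pre-identifier clients: plain "user" field
    "type": "m.login.password",
    "user": "alice",
    "password": "correct horse battery staple",
}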
async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any],
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
Args:
username: the username, from the identifier dict
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
@ -827,38 +991,18 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
if username.startswith("@"):
qualified_user_id = username
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
login_type = login_submission.get("type")
# we already checked that we have a valid login type
assert isinstance(login_type, str)
known_login_type = False
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
if not password:
raise SynapseError(400, "Missing parameter: password")
for provider in self.password_providers:
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
is_valid = await provider.check_password(qualified_user_id, password)
if is_valid:
return qualified_user_id, None
if not hasattr(provider, "get_supported_login_types") or not hasattr(
provider, "check_auth"
):
# this password provider doesn't understand custom login types
continue
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
@ -883,15 +1027,17 @@ class AuthHandler(BaseHandler):
result = await provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
# we've already checked that there is a (valid) password field
password = login_submission["password"]
assert isinstance(password, str)
canonical_user_id = await self._check_local_password(
qualified_user_id, password # type: ignore
qualified_user_id, password
)
if canonical_user_id:
@ -922,18 +1068,8 @@ class AuthHandler(BaseHandler):
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
if hasattr(provider, "check_3pid_auth"):
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
result = (result, None)
return result
return None, None
@ -992,16 +1128,11 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
await provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
if inspect.isawaitable(result):
await result
# delete pushers associated with this access token
if user_info.token_id is not None:
@ -1030,7 +1161,6 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
@ -1358,3 +1488,127 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
class PasswordProvider:
"""Wrapper for a password auth provider module
This class abstracts out all of the backwards-compatibility hacks for
password providers, to provide a consistent interface.
"""
@classmethod
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
return cls(pp, module_api)
def __init__(self, pp, module_api: ModuleApi):
self._pp = pp
self._module_api = module_api
self._supported_login_types = {}
# grandfather in check_password support
if hasattr(self._pp, "check_password"):
self._supported_login_types[LoginType.PASSWORD] = ("password",)
g = getattr(self._pp, "get_supported_login_types", None)
if g:
self._supported_login_types.update(g())
def __str__(self):
return str(self._pp)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
Returns a map from a login type identifier (such as m.login.password) to an
iterable giving the fields which must be provided by the user in the submission
to the /login API.
This wrapper adds m.login.password to the list if the underlying password
provider supports the check_password() api.
"""
return self._supported_login_types
async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
"""Check if the user has presented valid login credentials
This wrapper also calls check_password() if the underlying password provider
supports the check_password() api and the login type is m.login.password.
Args:
username: user id presented by the client. Either an MXID or an unqualified
username.
login_type: the login type being attempted - one of the types returned by
get_supported_login_types()
login_dict: the dictionary of login secrets passed by the client.
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
user, and `callback` is an optional callback which will be called with the
result from the /login call (including access_token, device_id, etc.)
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
g = getattr(self._pp, "check_password", None)
if g:
qualified_user_id = self._module_api.get_qualified_user_id(username)
is_valid = await self._pp.check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None
g = getattr(self._pp, "check_auth", None)
if not g:
return None
result = await g(username, login_type, login_dict)
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
return result
async def check_3pid_auth(
self, medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
g = getattr(self._pp, "check_3pid_auth", None)
if not g:
return None
# This function is able to return a deferred that either
# resolves to None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await g(medium, address, password)
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
return result
async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
g = getattr(self._pp, "on_logged_out", None)
if not g:
return
# This might return an awaitable; if it does, block the log out
# until it completes.
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
if inspect.isawaitable(result):
await result
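A minimal sketch of a legacy module that PasswordProvider.load would wrap: it implements only check_password, so the wrapper grandfathers it in by advertising m.login.password with a single "password" field on its behalf (module name and credential check are illustrative).

class ExamplePasswordProvider:
    """Old-style password provider: no get_supported_login_types/check_auth."""

    def __init__(self, config, account_handler):
        self._config = config
        self._account_handler = account_handler  # a ModuleApi in Synapse

    @staticmethod
    def parse_config(config):
        return config

    async def check_password(self, user_id: str, password: str) -> bool:
        # Receives a fully-qualified user ID and the submitted password;
        # returning True authenticates the user.
        return password == "hunter2"  # illustrative only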

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib
from typing import Dict, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from xml.etree import ElementTree as ET
from twisted.web.client import PartialDownloadError
@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -31,10 +34,10 @@ class CasHandler:
Utility class to handle the response from a CAS SSO service.
Args:
hs (synapse.server.HomeServer)
hs
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
@ -200,27 +203,57 @@ class CasHandler:
args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args)
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if session:
await self._auth_handler.complete_sso_ui_auth(
registered_user_id, session, request,
)
else:
if not registered_user_id:
# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=user_display_name,
user_agent_ips=(user_agent, ip_address),
# Get the matrix ID from the CAS username.
user_id = await self._map_cas_user_to_matrix_user(
username, user_display_name, user_agent, ip_address
)
await self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url
if session:
await self._auth_handler.complete_sso_ui_auth(
user_id, session, request,
)
else:
# If this is not a UI auth request then there must be a redirect URL.
assert client_redirect_url
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url
)
async def _map_cas_user_to_matrix_user(
self,
remote_user_id: str,
display_name: Optional[str],
user_agent: str,
ip_address: str,
) -> str:
"""
Given a CAS username, retrieve the user ID for it and possibly register the user.
Args:
remote_user_id: The username from the CAS response.
display_name: The display name from the CAS response.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
The user ID associated with this response.
"""
localpart = map_username_to_mxid_localpart(remote_user_id)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
# If the user does not exist, register it.
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=display_name,
user_agent_ips=[(user_agent, ip_address)],
)
return registered_user_id
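The check-then-register flow factored out here, as a standalone sketch with an in-memory store standing in for the auth and registration handlers (names are illustrative):

import asyncio
from typing import Dict, Optional

class InMemoryRegistry:
    def __init__(self, server_name: str):
        self._server_name = server_name
        self._users = {}  # type: Dict[str, Optional[str]]

    async def map_remote_user_to_matrix_user(
        self, localpart: str, display_name: Optional[str]
    ) -> str:
        user_id = "@%s:%s" % (localpart, self._server_name)
        if user_id not in self._users:
            # First login for this remote user: register them on the fly,
            # keeping the display name supplied by the IdP.
            self._users[user_id] = display_name
        return user_id

registry = InMemoryRegistry("example.com")
assert asyncio.get_event_loop().run_until_complete(
    registry.map_remote_user_to_matrix_user("alice", "Alice")
) == "@alice:example.com"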

View file

@ -39,6 +39,7 @@ class DeactivateAccountHandler(BaseHandler):
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self._server_name = hs.hostname
# Flag that indicates whether the process to part users from rooms is running
self._user_parter_running = False
@ -152,7 +153,7 @@ class DeactivateAccountHandler(BaseHandler):
for room in pending_invites:
try:
await self._room_member_handler.update_membership(
create_requester(user),
create_requester(user, authenticated_entity=self._server_name),
user,
room.room_id,
"leave",
@ -208,7 +209,7 @@ class DeactivateAccountHandler(BaseHandler):
logger.info("User parter parting %r from %r", user_id, room_id)
try:
await self._room_member_handler.update_membership(
create_requester(user),
create_requester(user, authenticated_entity=self._server_name),
user,
room_id,
"leave",

View file

@ -55,6 +55,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.handlers._base import BaseHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
@ -67,7 +68,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationFederationSendEventsRestServlet,
ReplicationStoreRoomOnInviteRestServlet,
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@ -152,12 +153,14 @@ class FederationHandler(BaseHandler):
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
hs
)
self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client(
self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
hs
)
else:
self._device_list_updater = hs.get_device_handler().device_list_updater
self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
self._maybe_store_room_on_outlier_membership = (
self.store.maybe_store_room_on_outlier_membership
)
# When joining a room we need to queue any events for that room up.
# For each room, a list of (pdu, origin) tuples.
@ -1617,7 +1620,7 @@ class FederationHandler(BaseHandler):
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
# join dance).
await self._maybe_store_room_on_invite(
await self._maybe_store_room_on_outlier_membership(
room_id=event.room_id, room_version=room_version
)
@ -2686,7 +2689,7 @@ class FederationHandler(BaseHandler):
)
async def on_exchange_third_party_invite_request(
self, room_id: str, event_dict: JsonDict
self, event_dict: JsonDict
) -> None:
"""Handle an exchange_third_party_invite request from a remote server
@ -2694,12 +2697,11 @@ class FederationHandler(BaseHandler):
into a normal m.room.member invite.
Args:
room_id: The ID of the room.
event_dict (dict[str, Any]): Dictionary containing the event body.
event_dict: Dictionary containing the event body.
"""
room_version = await self.store.get_room_version_id(room_id)
assert_params_in_dict(event_dict, ["room_id"])
room_version = await self.store.get_room_version_id(event_dict["room_id"])
# NB: event_dict has a particular specced format we might need to fudge
# if we change event formats too much.

View file

@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "An error was encountered when sending the email")
token_expires = (
self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
self.hs.get_clock().time_msec()
+ self.hs.config.email_validation_token_lifetime
)
await self.store.start_or_continue_validation_session(

View file

@ -474,7 +474,7 @@ class EventCreationHandler:
Returns:
Tuple of created event, Context
"""
await self.auth.check_auth_blocking(requester.user.to_string())
await self.auth.check_auth_blocking(requester=requester)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"]
@ -621,7 +621,13 @@ class EventCreationHandler:
if requester.app_service is not None:
return
user_id = requester.user.to_string()
user_id = requester.authenticated_entity
if not user_id.startswith("@"):
# The authenticated entity might not be a user, e.g. if it's the
# server puppetting the user.
return
user = UserID.from_string(user_id)
# exempt the system notices user
if (
@ -641,9 +647,7 @@ class EventCreationHandler:
if u["consent_version"] == self.config.user_consent_version:
return
consent_uri = self._consent_uri_builder.build_user_consent_uri(
requester.user.localpart
)
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@ -1255,7 +1259,7 @@ class EventCreationHandler:
for user_id in members:
if not self.hs.is_mine_id(user_id):
continue
requester = create_requester(user_id)
requester = create_requester(user_id, authenticated_entity=self.server_name)
try:
event, context = await self.create_event(
requester,
@ -1276,11 +1280,6 @@ class EventCreationHandler:
requester, event, context, ratelimit=False, ignore_shadow_ban=True,
)
return True
except ConsentNotGivenError:
logger.info(
"Failed to send dummy event into room %s for user %s due to "
"lack of consent. Will try another user" % (room_id, user_id)
)
except AuthError:
logger.info(
"Failed to send dummy event into room %s for user %s due to "

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@ -34,7 +35,8 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@ -83,17 +85,12 @@ class OidcError(Exception):
return self.error
class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object
"""
class OidcHandler:
class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
@ -120,36 +117,13 @@ class OidcHandler:
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._datastore = hs.get_datastore()
self._clock = hs.get_clock()
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"
def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error. Those include
well-known OAuth2/OIDC error types like invalid_request or
access_denied.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
self._sso_handler = hs.get_sso_handler()
def _validate_metadata(self):
"""Verifies the provider metadata.
@ -571,7 +545,7 @@ class OidcHandler:
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
``self._render_error`` which displays an HTML page for the error.
``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
@ -609,7 +583,7 @@ class OidcHandler:
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
self._render_error(request, error, description)
self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
@ -619,7 +593,9 @@ class OidcHandler:
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._render_error(request, "missing_session", "No session cookie found")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return
# Remove the cookie. There is a good chance that if the callback failed
@ -637,7 +613,9 @@ class OidcHandler:
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._render_error(request, "invalid_request", "State parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return
state = request.args[b"state"][0].decode()
@ -651,17 +629,19 @@ class OidcHandler:
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._render_error(request, "invalid_session", str(e))
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._render_error(request, "mismatching_session", str(e))
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._render_error(request, "invalid_request", "Code parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return
logger.debug("Exchanging code")
@ -670,7 +650,7 @@ class OidcHandler:
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
self._render_error(request, e.error, e.error_description)
self._sso_handler.render_error(request, e.error, e.error_description)
return
logger.debug("Successfully obtained OAuth2 access token")
@ -683,7 +663,7 @@ class OidcHandler:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
self._render_error(request, "fetch_error", str(e))
self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
logger.debug("Extracting userinfo from id_token")
@ -691,7 +671,7 @@ class OidcHandler:
userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._render_error(request, "invalid_token", str(e))
self._sso_handler.render_error(request, "invalid_token", str(e))
return
# Pull out the user-agent and IP from the request.
@ -705,7 +685,7 @@ class OidcHandler:
)
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Mapping providers might not have get_extra_attributes: only call this
@ -770,7 +750,7 @@ class OidcHandler:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
now = self._clock.time_msec()
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@ -845,7 +825,7 @@ class OidcHandler:
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
now = self.clock.time_msec()
return now < expiry
async def _map_userinfo_to_user(
@ -885,71 +865,77 @@ class OidcHandler:
# to be strings.
remote_user_id = str(remote_user_id)
logger.info(
"Looking for existing mapping for user %s:%s",
self._auth_provider_id,
remote_user_id,
# Older mapping providers don't accept the `failures` argument, so we
# try and detect support.
mapper_signature = inspect.signature(
self._user_mapping_provider.map_user_attributes
)
supports_failures = "failures" in mapper_signature.parameters
registered_user_id = await self._datastore.get_user_by_external_id(
self._auth_provider_id, remote_user_id,
)
async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
"""
Call the mapping provider to map the OIDC userinfo and token to user attributes.
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id
try:
This exists for backwards compatibility with the SSO handler abstraction.
"""
if supports_failures:
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token, failures
)
else:
# If the mapping provider does not support processing failures,
# do not continually generate the same Matrix ID, since it will
# already be in use. Note that the error raised is
# arbitrary and will get turned into a MappingException.
if failures:
raise RuntimeError(
"Mapping provider does not support de-duplicating Matrix IDs"
)
attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
userinfo, token
)
except Exception as e:
raise MappingException(
"Could not extract user attributes from OIDC response: " + str(e)
)
logger.debug(
"Retrieved user attributes from user mapping provider: %r", attributes
)
return UserAttributes(**attributes)
if not attributes["localpart"]:
raise MappingException("localpart is empty")
localpart = map_username_to_mxid_localpart(attributes["localpart"])
user_id = UserID(localpart, self._hostname).to_string()
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
if users:
async def grandfather_existing_users() -> Optional[str]:
if self._allow_existing_users:
# If allowing existing users we want to generate a single localpart
# and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0)
user_id = UserID(attributes.localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
registered_user_id = next(iter(users))
previously_registered_user_id = next(iter(users))
elif user_id in users:
registered_user_id = user_id
previously_registered_user_id = user_id
else:
# Do not attempt to continue generating Matrix IDs.
raise MappingException(
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
user_id, list(users.keys())
user_id, users
)
)
else:
# This mxid is taken
raise MappingException("mxid '{}' is already taken".format(user_id))
else:
# It's the first time this user is logging in and the mapped mxid was
# not taken, register the user
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address),
return previously_registered_user_id
return None
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
oidc_response_to_user_attributes,
grandfather_existing_users,
)
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
return registered_user_id
UserAttribute = TypedDict(
"UserAttribute", {"localpart": str, "display_name": Optional[str]}
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")
@ -992,13 +978,15 @@ class OidcMappingProvider(Generic[C]):
raise NotImplementedError()
async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict:
"""Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
failures: How many times a call to this function with this
UserInfo has resulted in a failure.
Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
@ -1098,10 +1086,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
return userinfo[self._config.subject_claim]
async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip()
# Ensure only valid characters are included in the MXID.
localpart = map_username_to_mxid_localpart(localpart)
# Append suffix integer if last call to this function failed to produce
# a usable mxid.
localpart += str(failures) if failures else ""
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
@ -1111,7 +1106,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
return UserAttribute(localpart=localpart, display_name=display_name)
return UserAttributeDict(localpart=localpart, display_name=display_name)
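A worked example of the retry suffix: for a template resolving to "john", successive mapping failures now yield distinct candidates instead of colliding forever.

def candidate_localpart(base: str, failures: int) -> str:
    # Mirrors the suffix logic above: append the failure count after the
    # first collision so each retry proposes a different mxid.
    return base + (str(failures) if failures else "")

assert [candidate_localpart("john", n) for n in range(3)] == ["john", "john1", "john2"]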
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]

View file

@ -299,15 +299,20 @@ class PaginationHandler:
"""
return self._purges_by_id.get(purge_id)
async def purge_room(self, room_id: str) -> None:
"""Purge the given room from the database"""
async def purge_room(self, room_id: str, force: bool = False) -> None:
"""Purge the given room from the database.
Args:
room_id: room to be purged
force: set true to skip checking for joined users.
"""
with await self.pagination_lock.write(room_id):
# check we know about the room
await self.store.get_room_version_id(room_id)
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
if joined:
raise SynapseError(400, "Users are still joined to this room")

View file

@ -25,7 +25,7 @@ The methods that define policy are:
import abc
import logging
from contextlib import contextmanager
from typing import Dict, Iterable, List, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
from prometheus_client import Counter
from typing_extensions import ContextManager
@ -46,8 +46,7 @@ from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
MYPY = False
if MYPY:
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)

View file

@ -206,7 +206,9 @@ class ProfileHandler(BaseHandler):
# the join event to update the displayname in the rooms.
# This must be done by the target user himself.
if by_admin:
requester = create_requester(target_user)
requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity,
)
await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set
@ -286,7 +288,9 @@ class ProfileHandler(BaseHandler):
# Same like set_displayname
if by_admin:
requester = create_requester(target_user)
requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)

View file

@ -158,7 +158,8 @@ class ReceiptEventSource:
if from_key == to_key:
return [], to_key
# We first need to fetch all new receipts
# Fetch all read receipts for all rooms, up to a limit of 100. This is ordered
# by most recent.
rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
from_key=from_key, to_key=to_key
)

View file

@ -15,10 +15,12 @@
"""Contains functions for registering clients."""
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
@ -32,16 +34,14 @@ from synapse.types import RoomAlias, UserID, create_requester
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
@ -52,6 +52,7 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid
self._server_name = hs.hostname
self.spam_checker = hs.get_spam_checker()
@ -70,7 +71,11 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
async def check_username(
self, localpart, guest_access_token=None, assigned_user_id=None, allow_invalid=False
self,
localpart: str,
guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None,
allow_invalid: bool = False,
):
# meow: allow admins to register invalid user ids
if not allow_invalid:
@ -143,39 +148,45 @@ class RegistrationHandler(BaseHandler):
async def register_user(
self,
localpart=None,
password_hash=None,
guest_access_token=None,
make_guest=False,
admin=False,
threepid=None,
user_type=None,
default_display_name=None,
address=None,
bind_emails=[],
by_admin=False,
user_agent_ips=None,
):
localpart: Optional[str] = None,
password_hash: Optional[str] = None,
guest_access_token: Optional[str] = None,
make_guest: bool = False,
admin: bool = False,
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
default_display_name: Optional[str] = None,
address: Optional[str] = None,
bind_emails: List[str] = [],
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
) -> str:
"""Registers a new client on the server.
Args:
localpart: The local part of the user ID to register. If None,
one will be generated.
password_hash (str|None): The hashed password to assign to this user so they can
password_hash: The hashed password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
user_type (str|None): type of user. One of the values from
guest_access_token: The access token used when this was a guest
account.
make_guest: True if the new user should be a guest,
false to add a regular user account.
admin: True if the user should be registered as a server admin.
threepid: The threepid used for registering, if any.
user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
default_display_name (unicode|None): if set, the new user's displayname
default_display_name: if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
address (str|None): the IP address used to perform the registration.
bind_emails (List[str]): list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the
address: the IP address used to perform the registration.
bind_emails: list of emails to bind to this account.
by_admin: True if this registration is being made via the
admin api, otherwise False.
user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
Returns:
str: user_id
The registered user_id.
Raises:
SynapseError if there was a problem registering.
"""
@ -241,8 +252,10 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
fail_count = 0
user = None
while not user:
# If a default display name is not given, generate one.
generate_display_name = default_display_name is None
# This breaks on successful registration *or* errors after 10 failures.
while True:
# Fail after being unable to find a suitable ID a few times
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
@ -251,7 +264,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
if default_display_name is None:
if generate_display_name:
default_display_name = localpart
try:
await self.register_with_store(
@ -267,8 +280,6 @@ class RegistrationHandler(BaseHandler):
break
except SynapseError:
# if user id is taken, just generate another
user = None
user_id = None
fail_count += 1
if not self.hs.config.user_consent_at_registration:
@ -300,7 +311,7 @@ class RegistrationHandler(BaseHandler):
return user_id
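The rewritten loop trades the old `while not user` sentinel for `while True` with an explicit break; stripped of the registration details, the control flow reduces to this sketch (attempt_registration is hypothetical):

fail_count = 0
while True:
    if fail_count > 10:
        raise SynapseError(500, "Unable to find a suitable guest user ID")
    try:
        attempt_registration()  # raises SynapseError if the generated ID is taken
        break                   # success
    except SynapseError:
        fail_count += 1         # collision: loop round and try another ID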
async def _create_and_join_rooms(self, user_id: str):
async def _create_and_join_rooms(self, user_id: str) -> None:
"""
Create the auto-join rooms and join or invite the user to them.
@ -323,7 +334,8 @@ class RegistrationHandler(BaseHandler):
requires_join = False
if self.hs.config.registration.auto_join_user_id:
fake_requester = create_requester(
self.hs.config.registration.auto_join_user_id
self.hs.config.registration.auto_join_user_id,
authenticated_entity=self._server_name,
)
# If the room requires an invite, add the user to the list of invites.
@ -335,7 +347,9 @@ class RegistrationHandler(BaseHandler):
# being necessary this will occur after the invite was sent.
requires_join = True
else:
fake_requester = create_requester(user_id)
fake_requester = create_requester(
user_id, authenticated_entity=self._server_name
)
# Choose whether to federate the new room.
if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
@ -368,7 +382,9 @@ class RegistrationHandler(BaseHandler):
# created it, then ensure the first user joins it.
if requires_join:
await room_member_handler.update_membership(
requester=create_requester(user_id),
requester=create_requester(
user_id, authenticated_entity=self._server_name
),
target=UserID.from_string(user_id),
room_id=info["room_id"],
# Since it was just created, there are no remote hosts.
@ -376,15 +392,10 @@ class RegistrationHandler(BaseHandler):
action="join",
ratelimit=False,
)
except ConsentNotGivenError as e:
# Technically not necessary to pull out this error though
# moving away from bare excepts is a good thing to do.
logger.error("Failed to join new user to %r: %r", r, e)
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
async def _join_rooms(self, user_id: str):
async def _join_rooms(self, user_id: str) -> None:
"""
Join or invite the user to the auto-join rooms.
@ -430,9 +441,13 @@ class RegistrationHandler(BaseHandler):
# Send the invite, if necessary.
if requires_invite:
# If an invite is required, there must be an auto-join user ID.
assert self.hs.config.registration.auto_join_user_id
await room_member_handler.update_membership(
requester=create_requester(
self.hs.config.registration.auto_join_user_id
self.hs.config.registration.auto_join_user_id,
authenticated_entity=self._server_name,
),
target=UserID.from_string(user_id),
room_id=room_id,
@ -443,7 +458,9 @@ class RegistrationHandler(BaseHandler):
# Send the join.
await room_member_handler.update_membership(
requester=create_requester(user_id),
requester=create_requester(
user_id, authenticated_entity=self._server_name
),
target=UserID.from_string(user_id),
room_id=room_id,
remote_room_hosts=remote_room_hosts,
@ -458,7 +475,7 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
async def _auto_join_rooms(self, user_id: str):
async def _auto_join_rooms(self, user_id: str) -> None:
"""Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created.
@ -481,16 +498,16 @@ class RegistrationHandler(BaseHandler):
else:
await self._join_rooms(user_id)
async def post_consent_actions(self, user_id):
async def post_consent_actions(self, user_id: str) -> None:
"""A series of registration actions that can only be carried out once consent
has been granted
Args:
user_id (str): The user to join
user_id: The user to join
"""
await self._auto_join_rooms(user_id)
async def appservice_register(self, user_localpart, as_token):
async def appservice_register(self, user_localpart: str, as_token: str) -> str:
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@ -515,7 +532,9 @@ class RegistrationHandler(BaseHandler):
)
return user_id
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
def check_user_id_not_appservice_exclusive(
self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
) -> None:
# don't allow people to register the server notices mxid
if self._server_notices_mxid is not None:
if user_id == self._server_notices_mxid:
@ -539,12 +558,12 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
def check_registration_ratelimit(self, address):
def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
Args:
address (str|None): the IP address used to perform the registration. If this is
address: the IP address used to perform the registration. If this is
None, no ratelimiting will be performed.
Raises:
@ -555,42 +574,39 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter.ratelimit(address)
def register_with_store(
async def register_with_store(
self,
user_id,
password_hash=None,
was_guest=False,
make_guest=False,
appservice_id=None,
create_profile_with_displayname=None,
admin=False,
user_type=None,
address=None,
shadow_banned=False,
):
user_id: str,
password_hash: Optional[str] = None,
was_guest: bool = False,
make_guest: bool = False,
appservice_id: Optional[str] = None,
create_profile_with_displayname: Optional[str] = None,
admin: bool = False,
user_type: Optional[str] = None,
address: Optional[str] = None,
shadow_banned: bool = False,
) -> None:
"""Register user in the datastore.
Args:
user_id (str): The desired user ID to register.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
user_id: The desired user ID to register.
password_hash: Optional. The password hash for this user.
was_guest: Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the new user should be a guest,
make_guest: True if the new user should be a guest,
false to add a regular user account.
appservice_id (str|None): The ID of the appservice registering the user.
create_profile_with_displayname (unicode|None): Optionally create a
appservice_id: The ID of the appservice registering the user.
create_profile_with_displayname: Optionally create a
profile for the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
admin: is an admin user?
user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the registration.
shadow_banned (bool): Whether to shadow-ban the user
Returns:
Awaitable
address: the IP address used to perform the registration.
shadow_banned: Whether to shadow-ban the user
"""
if self.hs.config.worker_app:
return self._register_client(
await self._register_client(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@ -603,7 +619,7 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
else:
return self.store.register_user(
await self.store.register_user(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@ -616,22 +632,24 @@ class RegistrationHandler(BaseHandler):
)
async def register_device(
self, user_id, device_id, initial_display_name, is_guest=False
):
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
is_guest: bool = False,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config.
Args:
user_id (str): full canonical @user:id
device_id (str|None): The device ID to check, or None to generate
a new one.
initial_display_name (str|None): An optional display name for the
device.
is_guest (bool): Whether this is a guest account
user_id: full canonical @user:id
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
Returns:
tuple[str, str]: Tuple of device ID and access token
Tuple of device ID and access token
"""
if self.hs.config.worker_app:
@ -651,7 +669,7 @@ class RegistrationHandler(BaseHandler):
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
device_id = await self.device_handler.check_device_registered(
registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
@ -661,20 +679,21 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = await self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, valid_until_ms=valid_until_ms
user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
)
return (device_id, access_token)
return (registered_device_id, access_token)
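A usage sketch for the typed register_device (values made up); note the returned device ID can differ from the one passed in, which is why the hunk above renames it to registered_device_id before minting the access token:

device_id, access_token = await registration_handler.register_device(
    user_id="@alice:example.com",
    device_id=None,                       # None asks the server to generate one
    initial_display_name="Alice's phone",
)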
async def post_registration_actions(self, user_id, auth_result, access_token):
async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]
) -> None:
"""A user has completed registration
Args:
user_id (str): The user ID that consented
auth_result (dict): The authenticated credentials of the newly
registered user.
access_token (str|None): The access token of the newly logged in
device, or None if `inhibit_login` enabled.
user_id: The user ID that consented
auth_result: The authenticated credentials of the newly registered user.
access_token: The access token of the newly logged in device, or
None if `inhibit_login` enabled.
"""
if self.hs.config.worker_app:
await self._post_registration_client(
@ -700,19 +719,20 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.TERMS in auth_result:
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
async def _on_user_consented(self, user_id, consent_version):
async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
"""A user consented to the terms on registration
Args:
user_id (str): The user ID that consented.
consent_version (str): version of the policy the user has
consented to.
user_id: The user ID that consented.
consent_version: version of the policy the user has consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
async def _register_email_threepid(self, user_id, threepid, token):
async def _register_email_threepid(
self, user_id: str, threepid: dict, token: Optional[str]
) -> None:
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
@ -721,10 +741,9 @@ class RegistrationHandler(BaseHandler):
Must be called on master.
Args:
user_id (str): id of user
threepid (object): m.login.email.identity auth response
token (str|None): access_token for the user, or None if not logged
in.
user_id: id of user
threepid: m.login.email.identity auth response
token: access_token for the user, or None if not logged in.
"""
reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd):
@ -750,6 +769,8 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token)
# The token better still exist.
assert user_tuple
token_id = user_tuple.token_id
await self.pusher_pool.add_pusher(
@ -764,14 +785,14 @@ class RegistrationHandler(BaseHandler):
data={},
)
async def _register_msisdn_threepid(self, user_id, threepid):
async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None:
"""Add a phone number as a 3pid identifier
Must be called on master.
Args:
user_id (str): id of user
threepid (object): m.login.msisdn auth response
user_id: id of user
threepid: m.login.msisdn auth response
"""
try:
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])


@ -587,7 +587,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
await self.auth.check_auth_blocking(user_id)
await self.auth.check_auth_blocking(requester=requester)
if (
self._server_notices_mxid is not None
@ -1269,7 +1269,9 @@ class RoomShutdownHandler:
400, "User must be our own: %s" % (new_room_user_id,)
)
room_creator_requester = create_requester(new_room_user_id)
room_creator_requester = create_requester(
new_room_user_id, authenticated_entity=requester_user_id
)
info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
@ -1309,7 +1311,9 @@ class RoomShutdownHandler:
try:
# Kick users from room
target_requester = create_requester(user_id)
target_requester = create_requester(
user_id, authenticated_entity=requester_user_id
)
_, stream_id = await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,


@ -31,7 +31,6 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
@ -347,7 +346,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# later on.
content = dict(content)
if not self.allow_per_room_profiles or requester.shadow_banned:
# allow the server notices mxid to set room-level profile
is_requester_server_notices_user = (
self._server_notices_mxid is not None
and requester.user.to_string() == self._server_notices_mxid
)
if (
not self.allow_per_room_profiles and not is_requester_server_notices_user
) or requester.shadow_banned:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.
@ -515,10 +522,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
invite = await self.store.get_invite_for_local_user_in_room(
user_id=target.to_string(), room_id=room_id
) # type: Optional[RoomsForUser]
if not invite:
(
current_membership_type,
current_membership_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
target.to_string(), room_id
)
if (
current_membership_type != Membership.INVITE
or not current_membership_event_id
):
logger.info(
"%s sent a leave request to %s, but that is not an active room "
"on this server, and there is no pending invite",
@ -528,6 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(404, "Not a known room")
invite = await self.store.get_event(current_membership_event_id)
logger.info(
"%s rejects invite to %s from %s", target, room_id, invite.sender
)
@ -965,6 +979,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor = hs.get_distributor()
self.distributor.declare("user_left_room")
self._server_name = hs.hostname
async def _is_remote_room_too_complex(
self, room_id: str, remote_room_hosts: List[str]
@ -1059,7 +1074,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id
# The room is too large. Leave.
requester = types.create_requester(user, None, False, False, None)
requester = types.create_requester(
user, authenticated_entity=self._server_name
)
await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
@ -1104,32 +1121,34 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
logger.warning("Failed to reject invite: %s", e)
return await self._locally_reject_invite(
return await self._generate_local_out_of_band_leave(
invite_event, txn_id, requester, content
)
async def _locally_reject_invite(
async def _generate_local_out_of_band_leave(
self,
invite_event: EventBase,
previous_membership_event: EventBase,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
"""Generate a local invite rejection
"""Generate a local leave event for a room
This is called after we fail to reject an invite via a remote server. It
generates an out-of-band membership event locally.
This can be called after we e.g. fail to reject an invite via a remote server.
It generates an out-of-band membership event locally.
Args:
invite_event: the invite to be rejected
previous_membership_event: the previous membership event for this user
txn_id: optional transaction ID supplied by the client
requester: user making the rejection request, according to the access token
content: additional content to include in the rejection event.
requester: user making the request, according to the access token
content: additional content to include in the leave event.
Normally an empty dict.
"""
room_id = invite_event.room_id
target_user = invite_event.state_key
Returns:
A tuple containing (event_id, stream_id of the leave event)
"""
room_id = previous_membership_event.room_id
target_user = previous_membership_event.state_key
content["membership"] = Membership.LEAVE
@ -1141,12 +1160,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user,
}
# the auth events for the new event are the same as that of the invite, plus
# the invite itself.
# the auth events for the new event are the same as that of the previous event, plus
# the event itself.
#
# the prev_events are just the invite.
prev_event_ids = [invite_event.event_id]
auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
# the prev_events consist solely of the previous membership event.
prev_event_ids = [previous_membership_event.event_id]
auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event(
requester,


@ -24,7 +24,8 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.http.server import respond_with_html
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@ -37,15 +38,11 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
import synapse.server
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class MappingException(Exception):
"""Used to catch errors when mapping the SAML2 response to a user."""
@attr.s(slots=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
@ -57,17 +54,14 @@ class Saml2SessionData:
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
class SamlHandler:
def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
class SamlHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._saml_idp_entityid = hs.config.saml2_idp_entityid
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._clock = hs.get_clock()
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
@ -88,26 +82,9 @@ class SamlHandler:
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
self._sso_handler = hs.get_sso_handler()
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@ -124,13 +101,13 @@ class SamlHandler:
URL to redirect to
"""
reqid, info = self._saml_client.prepare_for_authenticate(
relay_state=client_redirect_url
entityid=self._saml_idp_entityid, relay_state=client_redirect_url
)
# Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,))
now = self._clock.time_msec()
now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id,
)
@ -171,12 +148,12 @@ class SamlHandler:
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
self._render_error(
self._sso_handler.render_error(
request, "unsolicited_response", "Unexpected SAML2 login."
)
return
except Exception as e:
self._render_error(
self._sso_handler.render_error(
request,
"invalid_response",
"Unable to parse SAML2 response: %s." % (e,),
@ -184,7 +161,7 @@ class SamlHandler:
return
if saml2_auth.not_signed:
self._render_error(
self._sso_handler.render_error(
request, "unsigned_respond", "SAML2 response was not signed."
)
return
@ -210,7 +187,7 @@ class SamlHandler:
# attributes.
for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement):
self._render_error(
self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return
@ -226,7 +203,7 @@ class SamlHandler:
)
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Complete the interactive auth session or the login.
@ -272,20 +249,26 @@ class SamlHandler:
"Failed to extract remote user id from SAML response"
)
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
"Looking for existing mapping for user %s:%s",
self._auth_provider_id,
remote_user_id,
)
registered_user_id = await self._datastore.get_user_by_external_id(
self._auth_provider_id, remote_user_id
)
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
"""
Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
This provides backwards compatibility with the SSO handler abstraction.
"""
# Call the mapping provider.
result = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, failures, client_redirect_url
)
# Remap some of the results.
return UserAttributes(
localpart=result.get("mxid_localpart"),
display_name=result.get("displayname"),
emails=result.get("emails", []),
)
async def grandfather_existing_users() -> Optional[str]:
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@ -294,75 +277,35 @@ class SamlHandler:
):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID(
map_username_to_mxid_localpart(attrval), self._hostname
map_username_to_mxid_localpart(attrval), self.server_name
).to_string()
logger.info(
logger.debug(
"Looking for existing account based on mapped %s %s",
self._grandfathered_mxid_source_attribute,
user_id,
)
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, i, client_redirect_url=client_redirect_url,
return None
with (await self._mapping_lock.queue(self._auth_provider_id)):
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
saml_response_to_remapped_user_attributes,
grandfather_existing_users,
)
logger.debug(
"Retrieved SAML attributes from user mapping provider: %s "
"(attempt %d)",
attribute_dict,
i,
)
localpart = attribute_dict.get("mxid_localpart")
if not localpart:
raise MappingException(
"Error parsing SAML2 response: SAML mapping provider plugin "
"did not return a mxid_localpart value"
)
displayname = attribute_dict.get("displayname")
emails = attribute_dict.get("emails", [])
# Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
):
# This mxid is free
break
else:
# Unable to generate a username in 1000 iterations
# Break and return error to the user
raise MappingException(
"Unable to generate a Matrix ID from the SAML response"
)
logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
if data.creation_time < expire_before:
@ -474,11 +417,11 @@ class DefaultSamlMappingProvider:
)
# Use the configured mapper for this mxid_source
base_mxid_localpart = self._mxid_mapper(mxid_source)
localpart = self._mxid_mapper(mxid_source)
# Append suffix integer if last call to this function failed to produce
# a usable mxid
localpart = base_mxid_localpart + (str(failures) if failures else "")
# a usable mxid.
localpart += str(failures) if failures else ""
# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead

synapse/handlers/sso.py (new file, 227 lines)

@ -0,0 +1,227 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
import attr
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object
"""
@attr.s
class UserAttributes:
localpart = attr.ib(type=str)
display_name = attr.ib(type=Optional[str], default=None)
emails = attr.ib(type=List[str], default=attr.Factory(list))
class SsoHandler(BaseHandler):
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
def render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Renders the error template and responds with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str
) -> Optional[str]:
"""
Maps the user ID of a remote IdP to a mxid for a previously seen user.
If the user has not been seen yet, this will return None.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
remote_user_id: The user ID according to the remote IdP. This might
be an e-mail address, a GUID, or some other form. It must be
unique and immutable.
Returns:
The mxid of a previously seen user.
"""
logger.debug(
"Looking for existing mapping for user %s:%s",
auth_provider_id,
remote_user_id,
)
# Check if we already have a mapping for this user.
previously_registered_user_id = await self.store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
# A match was found, return the user ID.
if previously_registered_user_id is not None:
logger.info(
"Found existing mapping for IdP '%s' and remote_user_id '%s': %s",
auth_provider_id,
remote_user_id,
previously_registered_user_id,
)
return previously_registered_user_id
# No match.
return None
async def get_mxid_from_sso(
self,
auth_provider_id: str,
remote_user_id: str,
user_agent: str,
ip_address: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
) -> str:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
This first checks if the SSO ID has previously been linked to a matrix ID,
if it has, that matrix ID is returned regardless of the current mapping
logic.
If a callable is provided for grandfathering users, it is called and can
potentially return a matrix ID to use. If it does, the SSO ID is linked to
this matrix ID for subsequent calls.
The mapping function is called (potentially multiple times) to generate
a localpart for the user.
If an unused localpart is generated, the user is registered from the
given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
remote_user_id: The unique identifier from the SSO provider.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the number of
times the returned mxid localpart mapping has failed.
grandfather_existing_users: A callable which can return a previously
existing matrix ID. The SSO ID is then linked to the returned
matrix ID.
Returns:
The user ID associated with the SSO response.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
# first of all, check if we already have a mapping for this user
previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id,
)
if previously_registered_user_id:
return previously_registered_user_id
# Check for grandfathering of users.
if grandfather_existing_users:
previously_registered_user_id = await grandfather_existing_users()
if previously_registered_user_id:
# Future logins should also match this user ID.
await self.store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id
# Otherwise, generate a new user.
for i in range(self._MAP_USERNAME_RETRIES):
try:
attributes = await sso_to_matrix_id_mapper(i)
except Exception as e:
raise MappingException(
"Could not extract user attributes from SSO response: " + str(e)
)
logger.debug(
"Retrieved user attributes from user mapping provider: %r (attempt %d)",
attributes,
i,
)
if not attributes.localpart:
raise MappingException(
"Error parsing SSO response: SSO mapping provider plugin "
"did not return a localpart value"
)
# Check if this mxid already exists
user_id = UserID(attributes.localpart, self.server_name).to_string()
if not await self.store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
# Unable to generate a username in 1000 iterations
# Break and return error to the user
raise MappingException(
"Unable to generate a Matrix ID from the SSO response"
)
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(attributes.localpart):
raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
logger.debug("Mapped SSO user to local part %s", attributes.localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=attributes.localpart,
default_display_name=attributes.display_name,
bind_emails=attributes.emails,
user_agent_ips=[(user_agent, ip_address)],
)
await self.store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
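A hedged sketch of how a protocol handler wires the two callbacks into get_mxid_from_sso, mirroring the SAML wiring earlier in this diff; map_attributes and the literal values are hypothetical:

async def map_attributes(failures: int) -> UserAttributes:
    # Suffix the failure count so a clashing localpart gets a fresh candidate.
    return UserAttributes(localpart="alice" + (str(failures) if failures else ""))

user_id = await sso_handler.get_mxid_from_sso(
    auth_provider_id="oidc",
    remote_user_id="remote-guid-1234",
    user_agent="Mozilla/5.0",
    ip_address="203.0.113.7",
    sso_to_matrix_id_mapper=map_attributes,
    grandfather_existing_users=None,  # no legacy lookup in this sketch
)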


@ -31,6 +31,7 @@ from synapse.types import (
Collection,
JsonDict,
MutableStateMap,
Requester,
RoomStreamToken,
StateMap,
StreamToken,
@ -260,6 +261,7 @@ class SyncHandler:
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
@ -273,7 +275,7 @@ class SyncHandler:
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
await self.auth.check_auth_blocking(user_id)
await self.auth.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap(
sync_config.request_key,


@ -14,9 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
import urllib.parse
from io import BytesIO
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Dict,
@ -31,7 +32,7 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
from netaddr import IPAddress
from netaddr import IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
@ -39,6 +40,8 @@ from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
)
@ -53,7 +56,7 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from twisted.web.iweb import IAgent, IBodyProducer, IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
def check_against_blacklist(
ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
) -> bool:
"""
Compares an IP address to allowed and disallowed IP sets.
Args:
ip_address (netaddr.IPAddress)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
ip_address: The IP address to check
ip_whitelist: Allowed IP addresses.
ip_blacklist: Disallowed IP addresses.
Returns:
True if the IP address is in the blacklist and not in the whitelist.
"""
if ip_address in ip_blacklist:
if ip_whitelist is None or ip_address not in ip_whitelist:
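A quick behaviour sketch of the newly typed helper; the addresses are illustrative:

from netaddr import IPAddress, IPSet

blacklist = IPSet(["10.0.0.0/8"])
whitelist = IPSet(["10.1.2.3"])
check_against_blacklist(IPAddress("10.1.2.3"), whitelist, blacklist)  # False: blacklisted but whitelisted
check_against_blacklist(IPAddress("10.9.9.9"), whitelist, blacklist)  # True: blacklisted, not whitelisted
check_against_blacklist(IPAddress("8.8.8.8"), whitelist, blacklist)   # False: not in the blacklist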
@ -118,23 +131,30 @@ class IPBlacklistingResolver:
addresses, preventing DNS rebinding attacks on URL preview.
"""
def __init__(self, reactor, ip_whitelist, ip_blacklist):
def __init__(
self,
reactor: IReactorPluggableNameResolver,
ip_whitelist: Optional[IPSet],
ip_blacklist: IPSet,
):
"""
Args:
reactor (twisted.internet.reactor)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
reactor: The twisted reactor.
ip_whitelist: IP addresses to allow.
ip_blacklist: IP addresses to disallow.
"""
self._reactor = reactor
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
def resolveHostName(self, recv, hostname, portNumber=0):
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
r = recv()
addresses = []
addresses = [] # type: List[IAddress]
def _callback():
def _callback() -> None:
r.resolutionBegan(None)
has_bad_ip = False
@ -161,15 +181,15 @@ class IPBlacklistingResolver:
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress):
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
@staticmethod
def addressResolved(address):
def addressResolved(address: IAddress) -> None:
addresses.append(address)
@staticmethod
def resolutionComplete():
def resolutionComplete() -> None:
_callback()
self._reactor.nameResolver.resolveHostName(
@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent):
directly (without an IP address lookup).
"""
def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
def __init__(
self,
agent: IAgent,
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
):
"""
Args:
agent (twisted.web.client.Agent): The Agent to wrap.
reactor (twisted.internet.reactor)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
agent: The Agent to wrap.
ip_whitelist: IP addresses to allow.
ip_blacklist: IP addresses to disallow.
"""
self._agent = agent
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
def request(self, method, uri, headers=None, bodyProducer=None):
def request(
self,
method: bytes,
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
) -> defer.Deferred:
h = urllib.parse.urlparse(uri.decode("ascii"))
try:
@ -226,24 +256,24 @@ class SimpleHttpClient:
def __init__(
self,
hs,
treq_args={},
ip_whitelist=None,
ip_blacklist=None,
http_proxy=None,
https_proxy=None,
user_agent=None,
hs: "HomeServer",
treq_args: Dict[str, Any] = {},
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
http_proxy: Optional[bytes] = None,
https_proxy: Optional[bytes] = None,
user_agent: Optional[str] = None,
):
"""
Args:
hs (synapse.server.HomeServer)
treq_args (dict): Extra keyword arguments to be given to treq.request.
ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
hs
treq_args: Extra keyword arguments to be given to treq.request.
ip_blacklist: The IP addresses that are blacklisted that
we may not request.
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
http_proxy (bytes): proxy server to use for http connections. host[:port]
https_proxy (bytes): proxy server to use for https connections. host[:port]
http_proxy: proxy server to use for http connections. host[:port]
https_proxy: proxy server to use for https connections. host[:port]
"""
self.hs = hs
@ -307,7 +337,6 @@ class SimpleHttpClient:
# by the DNS resolution.
self.agent = BlacklistingAgentWrapper(
self.agent,
self.reactor,
ip_whitelist=self._ip_whitelist,
ip_blacklist=self._ip_blacklist,
)
@ -398,7 +427,7 @@ class SimpleHttpClient:
async def post_urlencoded_get_json(
self,
uri: str,
args: Mapping[str, Union[str, List[str]]] = {},
args: Optional[Mapping[str, Union[str, List[str]]]] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""
@ -423,9 +452,7 @@ class SimpleHttpClient:
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
"utf8"
)
query_bytes = encode_query_args(args)
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@ -433,7 +460,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
@ -480,7 +507,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
@ -496,7 +523,10 @@ class SimpleHttpClient:
)
async def get_json(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
self,
uri: str,
args: Optional[QueryParams] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""Gets some json from the given URI.
@ -517,7 +547,7 @@ class SimpleHttpClient:
"""
actual_headers = {b"Accept": [b"application/json"]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
@ -526,7 +556,7 @@ class SimpleHttpClient:
self,
uri: str,
json_body: Any,
args: QueryParams = {},
args: Optional[QueryParams] = None,
headers: RawHeaders = None,
) -> Any:
"""Puts some json to the given URI.
@ -547,9 +577,9 @@ class SimpleHttpClient:
ValueError: if the response was not JSON
"""
if len(args):
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
if args:
query_str = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_str)
json_str = encode_canonical_json(json_body)
@ -559,7 +589,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
@ -575,7 +605,10 @@ class SimpleHttpClient:
)
async def get_raw(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
self,
uri: str,
args: Optional[QueryParams] = None,
headers: Optional[RawHeaders] = None,
) -> bytes:
"""Gets raw text from the given URI.
@ -593,13 +626,13 @@ class SimpleHttpClient:
HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
if args:
query_str = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_str)
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request("GET", uri, headers=Headers(actual_headers))
@ -642,7 +675,7 @@ class SimpleHttpClient:
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request("GET", url, headers=Headers(actual_headers))
@ -650,12 +683,13 @@ class SimpleHttpClient:
if (
b"Content-Length" in resp_headers
and max_size
and int(resp_headers[b"Content-Length"][0]) > max_size
):
logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
logger.warning("Requested URL is too large > %r bytes" % (max_size,))
raise SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
"Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
@ -669,7 +703,7 @@ class SimpleHttpClient:
try:
length = await make_deferred_yieldable(
_readBodyToFile(response, output_stream, max_size)
readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
# This can happen e.g. because the body is too large.
@ -697,18 +731,16 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
def dataReceived(self, data):
def dataReceived(self, data: bytes) -> None:
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
@ -722,7 +754,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason):
def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
@ -733,35 +765,48 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
def readBodyToFile(
response: IResponse, stream: BinaryIO, max_size: Optional[int]
) -> defer.Deferred:
"""
Read an HTTP response body into a file-object, optionally enforcing a maximum file size.
Args:
response: The HTTP response to read from.
stream: The file-object to write to.
max_size: The maximum file size to allow.
Returns:
A Deferred which resolves to the length of the read body.
"""
def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
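A usage sketch for the now-public readBodyToFile, matching the download_file call site earlier in this diff:

length = await make_deferred_yieldable(
    readBodyToFile(response, output_stream, max_size)
)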
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.
Args:
args: The query arguments, a mapping of string to string or list of strings.
def encode_urlencode_arg(arg):
if isinstance(arg, str):
return arg.encode("utf-8")
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
else:
return arg
Returns:
The query arguments encoded as bytes.
"""
if args is None:
return b""
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, str):
vs = [vs]
encoded_args[k] = [v.encode("utf8") for v in vs]
def _print_ex(e):
if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons:
_print_ex(ex)
else:
logger.exception(e)
query_str = urllib.parse.urlencode(encoded_args, True)
return query_str.encode("utf8")
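A quick illustration of encode_query_args; the expected results are shown as comments:

encode_query_args({"q": "hello world", "tag": ["a", "b"]})
# => b'q=hello+world&tag=a&tag=b'
encode_query_args(None)
# => b''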
class InsecureInterceptableContextFactory(ssl.ContextFactory):


@ -12,21 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from typing import List
import urllib.parse
from typing import List, Optional
from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.web.client import Agent, HTTPConnectionPool
from twisted.internet.interfaces import (
IProtocolFactory,
IReactorCore,
IStreamClientEndpoint,
)
from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IAgentEndpointFactory
from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
@ -44,30 +48,30 @@ class MatrixFederationAgent:
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
Args:
reactor (IReactor): twisted reactor to use for underlying requests
reactor: twisted reactor to use for underlying requests
tls_client_options_factory (FederationPolicyForHTTPS|None):
tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.
user_agent (bytes):
user_agent:
The user agent header to use for federation requests.
_srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
_srv_resolver:
SrvResolver implementation to use for looking up SRV records. None
to use a default implementation.
_well_known_resolver (WellKnownResolver|None):
_well_known_resolver:
WellKnownResolver to use to perform well-known lookups. None to use a
default implementation.
"""
def __init__(
self,
reactor,
tls_client_options_factory,
user_agent,
_srv_resolver=None,
_well_known_resolver=None,
reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
_srv_resolver: Optional[SrvResolver] = None,
_well_known_resolver: Optional[WellKnownResolver] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@ -99,15 +103,20 @@ class MatrixFederationAgent:
self._well_known_resolver = _well_known_resolver
@defer.inlineCallbacks
def request(self, method, uri, headers=None, bodyProducer=None):
def request(
self,
method: bytes,
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
) -> defer.Deferred:
"""
Args:
method (bytes): HTTP method: GET/POST/etc
uri (bytes): Absolute URI to be retrieved
headers (twisted.web.http_headers.Headers|None):
HTTP headers to send with the request, or None to
send no extra headers.
bodyProducer (twisted.web.iweb.IBodyProducer|None):
method: HTTP method: GET/POST/etc
uri: Absolute URI to be retrieved
headers:
HTTP headers to send with the request, or None to send no extra headers.
bodyProducer:
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
@ -123,6 +132,9 @@ class MatrixFederationAgent:
# explicit port.
parsed_uri = urllib.parse.urlparse(uri)
# There must be a valid hostname.
assert parsed_uri.hostname
# If this is a matrix:// URI check if the server has delegated matrix
# traffic using well-known delegation.
#
@ -179,7 +191,12 @@ class MatrixHostnameEndpointFactory:
"""Factory for MatrixHostnameEndpoint for parsing to an Agent.
"""
def __init__(self, reactor, tls_client_options_factory, srv_resolver):
def __init__(
self,
reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
srv_resolver: Optional[SrvResolver],
):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
@ -203,15 +220,20 @@ class MatrixHostnameEndpoint:
resolution (i.e. via SRV). Does not check for well-known delegation.
Args:
reactor (IReactor)
tls_client_options_factory (ClientTLSOptionsFactory|None):
reactor: twisted reactor to use for underlying requests
tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.
srv_resolver (SrvResolver): The SRV resolver to use
parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
to connect to.
srv_resolver: The SRV resolver to use
parsed_uri: The parsed URI that we're wanting to connect to.
"""
def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
def __init__(
self,
reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
srv_resolver: SrvResolver,
parsed_uri: URI,
):
self._reactor = reactor
self._parsed_uri = parsed_uri
@ -231,13 +253,13 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver
def connect(self, protocol_factory):
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
"""Implements IStreamClientEndpoint interface
"""
return run_in_background(self._do_connect, protocol_factory)
async def _do_connect(self, protocol_factory):
async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
first_exception = None
server_list = await self._resolve_server()
@ -303,20 +325,20 @@ class MatrixHostnameEndpoint:
return [Server(host, 8448)]
def _is_ip_literal(host):
def _is_ip_literal(host: bytes) -> bool:
"""Test if the given host name is either an IPv4 or IPv6 literal.
Args:
host (bytes)
host: The host name to check
Returns:
bool
True if the hostname is an IP address literal.
"""
host = host.decode("ascii")
host_str = host.decode("ascii")
try:
IPAddress(host)
IPAddress(host_str)
return True
except AddrFormatError:
return False

View file

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import time
@ -21,10 +20,11 @@ from typing import Callable, Dict, Optional, Tuple
import attr
from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime
from twisted.web.client import RedirectAgent, readBody
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from twisted.web.iweb import IAgent, IResponse
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
@ -81,11 +81,11 @@ class WellKnownResolver:
def __init__(
self,
reactor,
agent,
user_agent,
well_known_cache=None,
had_well_known_cache=None,
reactor: IReactorTime,
agent: IAgent,
user_agent: bytes,
well_known_cache: Optional[TTLCache] = None,
had_well_known_cache: Optional[TTLCache] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@ -127,7 +127,7 @@ class WellKnownResolver:
with Measure(self._clock, "get_well_known"):
result, cache_period = await self._fetch_well_known(
server_name
) # type: Tuple[Optional[bytes], float]
) # type: Optional[bytes], float
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:

View file

@ -17,8 +17,9 @@ import cgi
import logging
import random
import sys
import urllib
import urllib.parse
from io import BytesIO
from typing import Callable, Dict, List, Optional, Tuple, Union
import attr
import treq
@ -27,25 +28,27 @@ from prometheus_client import Counter
from signedjson.sign import sign_json
from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
SynapseError,
)
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
from synapse.http.client import (
BlacklistingAgentWrapper,
IPBlacklistingResolver,
encode_query_args,
readBodyToFile,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import (
@ -54,6 +57,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@ -76,47 +80,44 @@ MAXINT = sys.maxsize
_next_id = 1
QueryArgs = Dict[str, Union[str, List[str]]]
@attr.s(slots=True, frozen=True)
class MatrixFederationRequest:
method = attr.ib()
method = attr.ib(type=str)
"""HTTP method
:type: str
"""
path = attr.ib()
path = attr.ib(type=str)
"""HTTP path
:type: str
"""
destination = attr.ib()
destination = attr.ib(type=str)
"""The remote server to send the HTTP request to.
:type: str"""
"""
json = attr.ib(default=None)
json = attr.ib(default=None, type=Optional[JsonDict])
"""JSON to send in the body.
:type: dict|None
"""
json_callback = attr.ib(default=None)
json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]])
"""A callback to generate the JSON.
:type: func|None
"""
query = attr.ib(default=None)
query = attr.ib(default=None, type=Optional[dict])
"""Query arguments.
:type: dict|None
"""
txn_id = attr.ib(default=None)
txn_id = attr.ib(default=None, type=Optional[str])
"""Unique ID for this request (for logging)
:type: str|None
"""
uri = attr.ib(init=False, type=bytes)
"""The URI of this request
"""
def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
global _next_id
txn_id = "%s-O-%s" % (self.method, _next_id)
_next_id = (_next_id + 1) % (MAXINT - 1)
@ -136,7 +137,7 @@ class MatrixFederationRequest:
)
object.__setattr__(self, "uri", uri)
def get_json(self):
def get_json(self) -> Optional[JsonDict]:
if self.json_callback:
return self.json_callback()
return self.json
@ -148,7 +149,7 @@ async def _handle_json_response(
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
):
) -> JsonDict:
"""
Reads the JSON body of a response, with a timeout
@ -160,7 +161,7 @@ async def _handle_json_response(
start_ms: Timestamp when request was made
Returns:
dict: parsed JSON response
The parsed JSON response
"""
try:
check_content_type_is_json(response.headers)
@ -250,9 +251,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
self.agent,
self.reactor,
ip_blacklist=hs.config.federation_ip_range_blacklist,
self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
@ -266,27 +265,29 @@ class MatrixFederationHttpClient:
self._cooperator = Cooperator(scheduler=schedule)
async def _send_request_with_optional_trailing_slash(
self, request, try_trailing_slash_on_400=False, **send_request_args
):
self,
request: MatrixFederationRequest,
try_trailing_slash_on_400: bool = False,
**send_request_args
) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3
due to #3622.
Args:
request (MatrixFederationRequest): details of request to be sent
try_trailing_slash_on_400 (bool): Whether on receiving a 400
request: details of request to be sent
try_trailing_slash_on_400: Whether on receiving a 400
'M_UNRECOGNIZED' from the server to retry the request with a
trailing slash appended to the request path.
send_request_args (Dict): A dictionary of arguments to pass to
`_send_request()`.
send_request_args: A dictionary of arguments to pass to `_send_request()`.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
Returns:
Dict: Parsed JSON response body.
Parsed JSON response body.
"""
try:
response = await self._send_request(request, **send_request_args)
@ -313,24 +314,26 @@ class MatrixFederationHttpClient:
async def _send_request(
self,
request,
retry_on_dns_fail=True,
timeout=None,
long_retries=False,
ignore_backoff=False,
backoff_on_404=False,
):
request: MatrixFederationRequest,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
long_retries: bool = False,
ignore_backoff: bool = False,
backoff_on_404: bool = False,
) -> IResponse:
"""
Sends a request to the given server.
Args:
request (MatrixFederationRequest): details of request to be sent
request: details of request to be sent
timeout (int|None): number of milliseconds to wait for the response headers
retry_on_dns_fail: true if the request should be retied on DNS failures
timeout: number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
60s by default.
long_retries (bool): whether to use the long retry algorithm.
long_retries: whether to use the long retry algorithm.
The regular retry algorithm makes 4 attempts, with intervals
[0.5s, 1s, 2s].
@ -346,14 +349,13 @@ class MatrixFederationHttpClient:
NB: the long retry algorithm takes over 20 minutes to complete, with
a default timeout of 60s!
ignore_backoff (bool): true to ignore the historical backoff data
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): Back off if we get a 404
backoff_on_404: Back off if we get a 404
Returns:
twisted.web.client.Response: resolves with the HTTP
response object on success.
Resolves with the HTTP response object on success.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
@ -404,7 +406,7 @@ class MatrixFederationHttpClient:
)
# Inject the span into the headers
headers_dict = {}
headers_dict = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes]
@ -435,7 +437,7 @@ class MatrixFederationHttpClient:
data = encode_canonical_json(json)
producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator
)
) # type: Optional[IBodyProducer]
else:
producer = None
auth_headers = self.build_auth_headers(
@ -524,14 +526,16 @@ class MatrixFederationHttpClient:
)
body = None
e = HttpResponseException(response.code, response_phrase, body)
exc = HttpResponseException(
response.code, response_phrase, body
)
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
if response.code == 429:
raise RequestSendFailed(e, can_retry=True) from e
raise RequestSendFailed(exc, can_retry=True) from exc
else:
raise e
raise exc
break
except RequestSendFailed as e:
@ -582,22 +586,27 @@ class MatrixFederationHttpClient:
return response
def build_auth_headers(
self, destination, method, url_bytes, content=None, destination_is=None
):
self,
destination: Optional[bytes],
method: bytes,
url_bytes: bytes,
content: Optional[JsonDict] = None,
destination_is: Optional[bytes] = None,
) -> List[bytes]:
"""
Builds the Authorization headers for a federation request
Args:
destination (bytes|None): The destination homeserver of the request.
destination: The destination homeserver of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
method (bytes): The HTTP method of the request
url_bytes (bytes): The URI path of the request
content (object): The body of the request
destination_is (bytes): As 'destination', but if the destination is an
method: The HTTP method of the request
url_bytes: The URI path of the request
content: The body of the request
destination_is: As 'destination', but if the destination is an
identity server
Returns:
list[bytes]: a list of headers to be added as "Authorization:" headers
A list of headers to be added as "Authorization:" headers
"""
request = {
"method": method.decode("ascii"),
@ -629,33 +638,32 @@ class MatrixFederationHttpClient:
async def put_json(
self,
destination,
path,
args={},
data={},
json_data_callback=None,
long_retries=False,
timeout=None,
ignore_backoff=False,
backoff_on_404=False,
try_trailing_slash_on_400=False,
):
destination: str,
path: str,
args: Optional[QueryArgs] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
) -> Union[JsonDict, list]:
""" Sends the specified json data using PUT
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
args (dict): query params
data (dict): A dict containing the data that will be used as
destination: The remote server to send the HTTP request to.
path: The HTTP path.
args: query params
data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
json_data_callback: A callable returning the dict to
use as the request body.
long_retries (bool): whether to use the long retry algorithm. See
long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
timeout (int|None): number of milliseconds to wait for the response.
timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@ -663,19 +671,19 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
ignore_backoff (bool): true to ignore the historical backoff data
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as
backoff_on_404: True if we should count a 404 response as
a failure of the server (and should therefore back off future
requests).
try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end
of the request. Workaround for #3622 in Synapse <= v0.99.3. This
will be attempted before backing off if backing off has been
enabled.
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@ -721,29 +729,28 @@ class MatrixFederationHttpClient:
async def post_json(
self,
destination,
path,
data={},
long_retries=False,
timeout=None,
ignore_backoff=False,
args={},
):
destination: str,
path: str,
data: Optional[JsonDict] = None,
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryArgs] = None,
) -> Union[JsonDict, list]:
""" Sends the specified json data using POST
Args:
destination (str): The remote server to send the HTTP request
to.
destination: The remote server to send the HTTP request to.
path (str): The HTTP path.
path: The HTTP path.
data (dict): A dict containing the data that will be used as
data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
long_retries (bool): whether to use the long retry algorithm. See
long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
timeout (int|None): number of milliseconds to wait for the response.
timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@ -751,10 +758,10 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
ignore_backoff (bool): true to ignore the historical backoff data and
ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
args (dict): query params
args: query params
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@ -795,26 +802,25 @@ class MatrixFederationHttpClient:
async def get_json(
self,
destination,
path,
args=None,
retry_on_dns_fail=True,
timeout=None,
ignore_backoff=False,
try_trailing_slash_on_400=False,
):
destination: str,
path: str,
args: Optional[QueryArgs] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
) -> Union[JsonDict, list]:
""" GETs some json from the given host homeserver and path
Args:
destination (str): The remote server to send the HTTP request
to.
destination: The remote server to send the HTTP request to.
path (str): The HTTP path.
path: The HTTP path.
args (dict|None): A dictionary used to create query strings, defaults to
args: A dictionary used to create query strings, defaults to
None.
timeout (int|None): number of milliseconds to wait for the response.
timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@ -822,14 +828,14 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
ignore_backoff (bool): true to ignore the historical backoff data
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@ -870,24 +876,23 @@ class MatrixFederationHttpClient:
async def delete_json(
self,
destination,
path,
long_retries=False,
timeout=None,
ignore_backoff=False,
args={},
):
destination: str,
path: str,
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryArgs] = None,
) -> Union[JsonDict, list]:
"""Send a DELETE request to the remote expecting some json response
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
destination: The remote server to send the HTTP request to.
path: The HTTP path.
long_retries (bool): whether to use the long retry algorithm. See
long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
timeout (int|None): number of milliseconds to wait for the response.
timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@ -895,12 +900,12 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
ignore_backoff (bool): true to ignore the historical backoff data and
ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
args (dict): query params
args: query params
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@ -938,25 +943,25 @@ class MatrixFederationHttpClient:
async def get_file(
self,
destination,
path,
destination: str,
path: str,
output_stream,
args={},
retry_on_dns_fail=True,
max_size=None,
ignore_backoff=False,
):
args: Optional[QueryArgs] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
ignore_backoff: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
"""GETs a file from a given homeserver
Args:
destination (str): The remote server to send the HTTP request to.
path (str): The HTTP path to GET.
output_stream (file): File to write the response body to.
args (dict): Optional dictionary used to create the query string.
ignore_backoff (bool): true to ignore the historical backoff data
destination: The remote server to send the HTTP request to.
path: The HTTP path to GET.
output_stream: File to write the response body to.
args: Optional dictionary used to create the query string.
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
tuple[int, dict]: Resolves with an (int,dict) tuple of
Resolves with an (int,dict) tuple of
the file length and a dict of the response headers.
Raises:
@ -980,7 +985,7 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
try:
d = _readBodyToFile(response, output_stream, max_size)
d = readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
except Exception as e:
@ -1004,40 +1009,6 @@ class MatrixFederationHttpClient:
return (length, headers)
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
def dataReceived(self, data):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(
SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
)
)
self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason):
if reason.check(ResponseDone):
self.deferred.callback(self.length)
else:
self.deferred.errback(reason)
def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@ -1049,13 +1020,13 @@ def _flatten_response_never_received(e):
return repr(e)
def check_content_type_is_json(headers):
def check_content_type_is_json(headers: Headers) -> None:
"""
Check that a set of HTTP headers have a Content-Type header, and that it
is application/json.
Args:
headers (twisted.web.http_headers.Headers): headers to check
headers: headers to check
Raises:
RequestSendFailed: if the Content-Type header is missing or isn't JSON
@ -1078,18 +1049,3 @@ def check_content_type_is_json(headers):
),
can_retry=False,
)
def encode_query_args(args):
if args is None:
return b""
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, str):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.parse.urlencode(encoded_args, True)
return query_bytes.encode("utf8")

View file

@ -25,7 +25,7 @@ from io import BytesIO
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
import jinja2
from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
from canonicaljson import iterencode_canonical_json
from zope.interface import implementer
from twisted.internet import defer, interfaces
@ -94,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass
else:
respond_with_json(
request,
error_code,
error_dict,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
request, error_code, error_dict, send_cors=True,
)
@ -290,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource):
code,
response_object,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
)
@ -587,7 +582,6 @@ def respond_with_json(
code: int,
json_object: Any,
send_cors: bool = False,
pretty_print: bool = False,
canonical_json: bool = True,
):
"""Sends encoded JSON in response to the given request.
@ -598,8 +592,6 @@ def respond_with_json(
json_object: The object to serialize to JSON.
send_cors: Whether to send Cross-Origin Resource Sharing headers
https://fetch.spec.whatwg.org/#http-cors-protocol
pretty_print: Whether to include indentation and line-breaks in the
resulting JSON bytes.
canonical_json: Whether to use the canonicaljson algorithm when encoding
the JSON bytes.
@ -615,9 +607,6 @@ def respond_with_json(
)
return None
if pretty_print:
encoder = iterencode_pretty_printed_json
else:
if canonical_json:
encoder = iterencode_canonical_json
else:
@ -685,7 +674,7 @@ def set_cors_headers(request: Request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date",
)
@ -759,11 +748,3 @@ def finish_request(request: Request):
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
def _request_user_agent_is_curl(request: Request) -> bool:
user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
for user_agent in user_agents:
if b"curl" in user_agent:
return True
return False

View file

@ -49,6 +49,7 @@ class ModuleApi:
self._store = hs.get_datastore()
self._auth = hs.get_auth()
self._auth_handler = auth_handler
self._server_name = hs.hostname
# We expose these as properties below in order to attach a helpful docstring.
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
@ -336,7 +337,9 @@ class ModuleApi:
SynapseError if the event was not allowed.
"""
# Create a requester object
requester = create_requester(event_dict["sender"])
requester = create_requester(
event_dict["sender"], authenticated_entity=self._server_name
)
# Create and send the event
(

View file

@ -75,6 +75,7 @@ class HttpPusher:
self.failing_since = pusherdict["failing_since"]
self.timed_call = None
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
@ -136,7 +137,11 @@ class HttpPusher:
async def _update_badge(self):
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
badge = await push_tools.get_badge_count(
self.hs.get_datastore(),
self.user_id,
group_by_room=self._group_unread_count_by_room,
)
await self._send_badge(badge)
def on_timer(self):
@ -283,7 +288,11 @@ class HttpPusher:
return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
badge = await push_tools.get_badge_count(
self.hs.get_datastore(),
self.user_id,
group_by_room=self._group_unread_count_by_room,
)
event = await self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:

View file

@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore
async def get_badge_count(store, user_id):
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
@ -34,9 +34,15 @@ async def get_badge_count(store, user_id):
room_id, user_id, last_unread_event_id
)
)
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0
if notifs["notify_count"] == 0:
continue
if group_by_room:
# return one badge count per conversation
badge += 1
else:
# increment the badge count by the number of unread messages in the room
badge += notifs["notify_count"]
return badge

View file

@ -254,20 +254,20 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
return 200, {}
class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
Request format:
POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id
POST /_synapse/replication/store_room_on_outlier_membership/:room_id/:txn_id
{
"room_version": "1",
}
"""
NAME = "store_room_on_invite"
NAME = "store_room_on_outlier_membership"
PATH_ARGS = ("room_id",)
def __init__(self, hs):
@ -282,7 +282,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
async def _handle_request(self, request, room_id):
content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
await self.store.maybe_store_room_on_invite(room_id, room_version)
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {}
@ -291,4 +291,4 @@ def register_servlets(hs, http_server):
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server)
ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)
ReplicationStoreRoomOnOutlierMembershipRestServlet(hs).register(http_server)

View file

@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.http import Request
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@ -52,16 +53,23 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(
requester, room_id, user_id, remote_room_hosts, content
):
async def _serialize_payload( # type: ignore
requester: Requester,
room_id: str,
user_id: str,
remote_room_hosts: List[str],
content: JsonDict,
) -> JsonDict:
"""
Args:
requester(Requester)
room_id (str)
user_id (str)
remote_room_hosts (list[str]): Servers to try and join via
content(dict): The event content to use for the join event
requester: The user making the request according to the access token
room_id: The ID of the room.
user_id: The ID of the user.
remote_room_hosts: Servers to try and join via
content: The event content to use for the join event
Returns:
A dict representing the payload of the request.
"""
return {
"requester": requester.serialize(),
@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
async def _handle_request(self, request, room_id, user_id):
async def _handle_request( # type: ignore
self, request: Request, room_id: str, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@ -118,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
):
) -> JsonDict:
"""
Args:
invite_event_id: ID of the invite to be rejected
txn_id: optional transaction ID supplied by the client
requester: user making the rejection request, according to the access token
content: additional content to include in the rejection event.
invite_event_id: The ID of the invite to be rejected.
txn_id: Optional transaction ID supplied by the client
requester: User making the rejection request, according to the access token
content: Additional content to include in the rejection event.
Normally an empty dict.
Returns:
A dict representing the payload of the request.
"""
return {
"txn_id": txn_id,
@ -133,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"content": content,
}
async def _handle_request(self, request, invite_event_id):
async def _handle_request( # type: ignore
self, request: Request, invite_event_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
txn_id = content["txn_id"]
@ -174,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor()
@staticmethod
async def _serialize_payload(room_id, user_id, change):
async def _serialize_payload( # type: ignore
room_id: str, user_id: str, change: str
) -> JsonDict:
"""
Args:
room_id (str)
user_id (str)
change (str): "left"
room_id: The ID of the room.
user_id: The ID of the user.
change: "left"
Returns:
A dict representing the payload of the request.
"""
assert change == "left"
return {}
def _handle_request(self, request, room_id, user_id, change):
def _handle_request( # type: ignore
self, request: Request, room_id: str, user_id: str, change: str
) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id)
user = UserID.from_string(user_id)

View file

@ -21,11 +21,7 @@ import synapse
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
historical_admin_path_patterns,
)
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.rest.admin.devices import (
DeleteDevicesRestServlet,
DeviceRestServlet,
@ -61,6 +57,7 @@ from synapse.rest.admin.users import (
UserRestServletV2,
UsersRestServlet,
UsersRestServletV2,
UserTokenRestServlet,
WhoisRestServlet,
)
from synapse.types import RoomStreamToken
@ -83,7 +80,7 @@ class VersionServlet(RestServlet):
class PurgeHistoryRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns(
PATTERNS = admin_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
)
@ -168,9 +165,7 @@ class PurgeHistoryRestServlet(RestServlet):
class PurgeHistoryStatusRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns(
"/purge_history_status/(?P<purge_id>[^/]+)"
)
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
def __init__(self, hs):
"""
@ -223,6 +218,7 @@ def register_servlets(hs, http_server):
UserAdminServlet(hs).register(http_server)
UserMediaRestServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server)
UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)

View file

@ -22,28 +22,6 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
def historical_admin_path_patterns(path_regex):
"""Returns the list of patterns for an admin endpoint, including historical ones
This is a backwards-compatibility hack. Previously, the Admin API was exposed at
various paths under /_matrix/client. This function returns a list of patterns
matching those paths (as well as the new one), so that existing scripts which rely
on the endpoints being available there are not broken.
Note that this should only be used for existing endpoints: new ones should just
register for the /_synapse/admin path.
"""
return [
re.compile(prefix + path_regex)
for prefix in (
"^/_synapse/admin/v1",
"^/_matrix/client/api/v1/admin",
"^/_matrix/client/unstable/admin",
"^/_matrix/client/r0/admin",
)
]
def admin_patterns(path_regex: str, version: str = "v1"):
"""Returns the list of patterns for an admin endpoint

View file

@ -16,10 +16,7 @@ import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.rest.admin._base import (
assert_user_is_admin,
historical_admin_path_patterns,
)
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
logger = logging.getLogger(__name__)
@ -28,7 +25,7 @@ class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups
"""
PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
def __init__(self, hs):
self.group_server = hs.get_groups_server_handler()

View file

@ -22,7 +22,6 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
historical_admin_path_patterns,
)
logger = logging.getLogger(__name__)
@ -34,10 +33,10 @@ class QuarantineMediaInRoom(RestServlet):
"""
PATTERNS = (
historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+
# This path kept around for legacy reasons
historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
def __init__(self, hs):
@ -63,9 +62,7 @@ class QuarantineMediaByUser(RestServlet):
this server.
"""
PATTERNS = historical_admin_path_patterns(
"/user/(?P<user_id>[^/]+)/media/quarantine"
)
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
def __init__(self, hs):
self.store = hs.get_datastore()
@ -90,7 +87,7 @@ class QuarantineMediaByID(RestServlet):
it via this server.
"""
PATTERNS = historical_admin_path_patterns(
PATTERNS = admin_patterns(
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
@ -116,7 +113,7 @@ class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media")
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
self.store = hs.get_datastore()
@ -134,7 +131,7 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns("/purge_media_cache")
PATTERNS = admin_patterns("/purge_media_cache")
def __init__(self, hs):
self.media_repository = hs.get_media_repository()

View file

@ -29,7 +29,6 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
historical_admin_path_patterns,
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@ -44,7 +43,7 @@ class ShutdownRoomRestServlet(RestServlet):
joined to the new room.
"""
PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
def __init__(self, hs):
self.hs = hs
@ -71,14 +70,18 @@ class ShutdownRoomRestServlet(RestServlet):
class DeleteRoomRestServlet(RestServlet):
"""Delete a room from server. It is a combination and improvement of
shut down and purge room.
"""Delete a room from server.
It is a combination and improvement of shutdown and purge room.
Shuts down a room by removing all local users from the room.
Blocking all future invites and joins to the room is optional.
If desired any local aliases will be repointed to a new room
created by `new_room_user_id` and kicked users will be auto
created by `new_room_user_id` and kicked users will be auto-
joined to the new room.
It will remove all trace of a room from the database.
If 'purge' is true, it will remove all traces of a room from the database.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
@ -111,6 +114,14 @@ class DeleteRoomRestServlet(RestServlet):
Codes.BAD_JSON,
)
force_purge = content.get("force_purge", False)
if not isinstance(force_purge, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Param 'force_purge' must be a boolean, if given",
Codes.BAD_JSON,
)
ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id,
new_room_user_id=content.get("new_room_user_id"),
@ -122,7 +133,7 @@ class DeleteRoomRestServlet(RestServlet):
# Purge room
if purge:
await self.pagination_handler.purge_room(room_id)
await self.pagination_handler.purge_room(room_id, force=force_purge)
return (200, ret)
@ -309,7 +320,9 @@ class JoinRoomAliasServlet(RestServlet):
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
fake_requester = create_requester(target_user)
fake_requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
# send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id)

View file

@ -16,7 +16,7 @@ import hashlib
import hmac
import logging
from http import HTTPStatus
from typing import Tuple
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@ -33,10 +33,13 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
historical_admin_path_patterns,
)
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
_GET_PUSHERS_ALLOWED_KEYS = {
@ -52,7 +55,7 @@ _GET_PUSHERS_ALLOWED_KEYS = {
class UsersRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
def __init__(self, hs):
self.hs = hs
@ -335,7 +338,7 @@ class UserRegisterServlet(RestServlet):
nonce to the time it was generated, in int seconds.
"""
PATTERNS = historical_admin_path_patterns("/register")
PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60
def __init__(self, hs):
@ -458,7 +461,14 @@ class UserRegisterServlet(RestServlet):
class WhoisRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
path_regex = "/whois/(?P<user_id>[^/]*)$"
PATTERNS = (
admin_patterns(path_regex)
+
# URL for spec reason
# https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid
client_patterns("/admin" + path_regex, v1=True)
)
def __init__(self, hs):
self.hs = hs
@ -482,7 +492,7 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@ -513,7 +523,7 @@ class DeactivateAccountRestServlet(RestServlet):
class AccountValidityRenewServlet(RestServlet):
PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
PATTERNS = admin_patterns("/account_validity/validity$")
def __init__(self, hs):
"""
@ -556,9 +566,7 @@ class ResetPasswordRestServlet(RestServlet):
200 OK with empty object if success otherwise an error.
"""
PATTERNS = historical_admin_path_patterns(
"/reset_password/(?P<target_user_id>[^/]*)"
)
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
@ -600,7 +608,7 @@ class SearchUsersRestServlet(RestServlet):
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.hs = hs
@ -828,3 +836,52 @@ class UserMediaRestServlet(RestServlet):
ret["next_token"] = start + len(media)
return 200, ret
class UserTokenRestServlet(RestServlet):
"""An admin API for logging in as a user.
Example:
POST /_synapse/admin/v1/users/@test:example.com/login
{}
200 OK
{
"access_token": "<some_token>"
}
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request, user_id):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
if not self.hs.is_mine_id(user_id):
raise SynapseError(400, "Only local users can be logged in as")
body = parse_json_object_from_request(request, allow_empty_body=True)
valid_until_ms = body.get("valid_until_ms")
if valid_until_ms and not isinstance(valid_until_ms, int):
raise SynapseError(400, "'valid_until_ms' parameter must be an int")
if auth_user.to_string() == user_id:
raise SynapseError(400, "Cannot use admin API to login as self")
token = await self.auth_handler.get_access_token_for_user_id(
user_id=auth_user.to_string(),
device_id=None,
valid_until_ms=valid_until_ms,
puppets_user_id=user_id,
)
return 200, {"access_token": token}

View file

@ -19,10 +19,6 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
from synapse.handlers.auth import (
convert_client_dict_legacy_fields_to_identifier,
login_id_phone_to_thirdparty,
)
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@ -33,7 +29,6 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
@ -78,11 +73,6 @@ class LoginRestServlet(RestServlet):
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
)
self._failed_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
def on_GET(self, request: SynapseRequest):
flows = []
@ -140,27 +130,31 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
def _get_qualified_user_id(self, identifier):
if identifier["type"] != "m.id.user":
raise SynapseError(400, "Unknown login identifier type")
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
if identifier["user"].startswith("@"):
return identifier["user"]
else:
return UserID(identifier["user"], self.hs.hostname).to_string()
async def _do_appservice_login(
self, login_submission: JsonDict, appservice: ApplicationService
):
logger.info(
"Got appservice login request with identifier: %r",
login_submission.get("identifier"),
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
if not isinstance(identifier, dict):
raise SynapseError(
400, "Invalid identifier in login submission", Codes.INVALID_PARAM
)
identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
qualified_user_id = self._get_qualified_user_id(identifier)
# this login flow only supports identifiers of type "m.id.user".
if identifier.get("type") != "m.id.user":
raise SynapseError(
400, "Unknown login identifier type", Codes.INVALID_PARAM
)
user = identifier.get("user")
if not isinstance(user, str):
raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
if user.startswith("@"):
qualified_user_id = user
else:
qualified_user_id = UserID(user, self.hs.hostname).to_string()
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
@ -186,91 +180,9 @@ class LoginRestServlet(RestServlet):
login_submission.get("address"),
login_submission.get("user"),
)
identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
# convert phone type identifiers to generic threepids
if identifier["type"] == "m.id.phone":
identifier = login_id_phone_to_thirdparty(identifier)
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
address = identifier.get("address")
medium = identifier.get("medium")
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
# For emails, canonicalise the address.
# We store all email addresses canonicalised in the DB.
# (See add_threepid in synapse/handlers/auth.py)
if medium == "email":
try:
address = canonicalise_email(address)
except ValueError as e:
raise SynapseError(400, str(e))
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
# Check for login providers that support 3pid login types
(
canonical_user_id,
callback_3pid,
) = await self.auth_handler.check_password_provider_3pid(
medium, address, login_submission["password"]
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
result = await self._complete_login(
canonical_user_id, login_submission, callback_3pid
)
return result
# No password providers were able to handle this 3pid
# Check local store
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address
)
if not user_id:
logger.warning(
"unknown 3pid identifier medium %s, address %r", medium, address
)
# We mark that we've failed to log in here, as
# `check_password_provider_3pid` might have returned `None` due
# to an incorrect password, rather than the account not
# existing.
#
# If it returned None but the 3PID was bound then we won't hit
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
self._failed_attempts_ratelimiter.can_do_action((medium, address))
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {"type": "m.id.user", "user": user_id}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
qualified_user_id = self._get_qualified_user_id(identifier)
# Check if we've hit the failed ratelimit (but don't update it)
self._failed_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(), update=False
)
try:
canonical_user_id, callback = await self.auth_handler.validate_login(
identifier["user"], login_submission
login_submission, ratelimit=True
)
except LoginError:
# The user has failed to log in, so we need to update the rate
# limiter. Using `can_do_action` avoids us raising a ratelimit
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
raise
result = await self._complete_login(
canonical_user_id, login_submission, callback
)

View file

@ -18,7 +18,7 @@
import logging
import re
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership
@ -48,8 +48,7 @@ from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID,
from synapse.util import json_decoder
from synapse.util.stringutils import random_string
MYPY = False
if MYPY:
if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__)

View file

@ -115,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@ -387,7 +387,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@ -466,7 +466,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)

View file

@ -135,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@ -214,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(

View file

@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet):
)
with context:
sync_result = await self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
since_token=since_token,
timeout=timeout,

View file

@ -66,7 +66,7 @@ class LocalKey(Resource):
def __init__(self, hs):
self.config = hs.config
self.clock = hs.clock
self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)

View file

@ -27,7 +27,8 @@ import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
import twisted
import twisted.internet.base
import twisted.internet.tcp
from twisted.mail.smtp import sendmail
from twisted.web.iweb import IPolicyForHTTPS
@ -89,6 +90,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.search import SearchHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
@ -145,7 +147,8 @@ def cache_in_self(builder: T) -> T:
"@cache_in_self can only be used on functions starting with `get_`"
)
depname = builder.__name__[len("get_") :]
# get_attr -> _attr
depname = builder.__name__[len("get") :]
building = [False]
@ -233,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta):
self._instance_id = random_string(5)
self._instance_name = config.worker_name or "master"
self.clock = Clock(reactor)
self.distributor = Distributor()
self.registration_ratelimiter = Ratelimiter(
clock=self.clock,
rate_hz=config.rc_registration.per_second,
burst_count=config.rc_registration.burst_count,
)
self.version_string = version_string
self.datastores = None # type: Optional[Databases]
@ -299,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta):
def is_mine_id(self, string: str) -> bool:
return string.split(":", 1)[1] == self.hostname
@cache_in_self
def get_clock(self) -> Clock:
return self.clock
return Clock(self._reactor)
def get_datastore(self) -> DataStore:
if not self.datastores:
@ -317,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_config(self) -> HomeServerConfig:
return self.config
@cache_in_self
def get_distributor(self) -> Distributor:
return self.distributor
return Distributor()
@cache_in_self
def get_registration_ratelimiter(self) -> Ratelimiter:
return self.registration_ratelimiter
return Ratelimiter(
clock=self.get_clock(),
rate_hz=self.config.rc_registration.per_second,
burst_count=self.config.rc_registration.burst_count,
)
@cache_in_self
def get_federation_client(self) -> FederationClient:
@ -390,6 +391,10 @@ class HomeServer(metaclass=abc.ABCMeta):
else:
return FollowerTypingHandler(self)
@cache_in_self
def get_sso_handler(self) -> SsoHandler:
return SsoHandler(self)
@cache_in_self
def get_sync_handler(self) -> SyncHandler:
return SyncHandler(self)
@ -681,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_federation_ratelimiter(self) -> FederationRateLimiter:
return FederationRateLimiter(self.clock, config=self.config.rc_federation)
return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation)
@cache_in_self
def get_module_api(self) -> ModuleApi:

View file

@ -39,6 +39,7 @@ class ServerNoticesManager:
self._room_member_handler = hs.get_room_member_handler()
self._event_creation_handler = hs.get_event_creation_handler()
self._is_mine_id = hs.is_mine_id
self._server_name = hs.hostname
self._notifier = hs.get_notifier()
self.server_notices_mxid = self._config.server_notices_mxid
@ -72,7 +73,9 @@ class ServerNoticesManager:
await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid
requester = create_requester(system_mxid)
requester = create_requester(
system_mxid, authenticated_entity=self._server_name
)
logger.info("Sending server notice to %s", user_id)
@ -145,7 +148,9 @@ class ServerNoticesManager:
"avatar_url": self._config.server_notices_mxid_avatar_url,
}
requester = create_requester(self.server_notices_mxid)
requester = create_requester(
self.server_notices_mxid, authenticated_entity=self._server_name
)
info, _ = await self._room_creation_handler.create_room(
requester,
config={
@ -174,7 +179,9 @@ class ServerNoticesManager:
user_id: The ID of the user to invite.
room_id: The ID of the room to invite the user to.
"""
requester = create_requester(self.server_notices_mxid)
requester = create_requester(
self.server_notices_mxid, authenticated_entity=self._server_name
)
# Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them.
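All three call sites in this file now record the homeserver itself as the authenticated entity. Roughly, with an illustrative server name (the full create_requester signature is abridged here):

from synapse.types import create_requester

# Act as the notices user, while recording "example.com" (the server)
# as the entity that actually authenticated the request.
requester = create_requester(
    "@notices:example.com", authenticated_entity="example.com"
)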

View file

@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
for table in (
"event_auth",
"event_edges",
"event_json",
"event_push_actions_staging",
"event_reference_hashes",
"event_relations",
@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"destination_rooms",
"event_backward_extremities",
"event_forward_extremities",
"event_json",
"event_push_actions",
"event_search",
"events",

View file

@ -278,7 +278,8 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]:
"""Get receipts for all rooms between two stream_ids.
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
Args:
to_key: Max stream id to fetch receipts up to.
@ -294,12 +295,16 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
txn.execute(sql, [from_key, to_key])
else:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
txn.execute(sql, [to_key])
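A hedged usage sketch of the method above; with the new ORDER BY/LIMIT, at most the newest 100 receipts in the window are returned (variable names are illustrative):

async def latest_receipts(store, last_seen_stream_id: int, current_stream_id: int):
    # Keyed by room ID; capped at the latest 100 linearized receipts.
    return await store.get_linearized_receipts_for_all_rooms(
        to_key=current_stream_id, from_key=last_seen_stream_id
    )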

View file

@ -1110,6 +1110,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
token: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
) -> int:
"""Adds an access token for the given user.
@ -1133,6 +1134,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"token": token,
"device_id": device_id,
"valid_until_ms": valid_until_ms,
"puppets_user_id": puppets_user_id,
},
desc="add_access_token_to_user",
)
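The new column backs the v1.24 "log in as a user" admin API (#8617). A hedged sketch of calling it, with the endpoint path and response shape assumed from the admin API docs:

import requests

# Hypothetical host and user; requires an admin access token.
resp = requests.post(
    "https://synapse.example.com/_synapse/admin/v1/users/@alice:example.com/login",
    headers={"Authorization": "Bearer ADMIN_TOKEN"},
    json={},
)
# resp.json() should contain an access_token that acts as @alice;
# the puppets_user_id column ties that token to the puppeted account.
print(resp.json())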

View file

@ -1240,13 +1240,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion):
async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion
):
"""
When we receive an invite over federation, store the version of the room if we
don't already know the room version.
When we receive an invite or any other event over federation that may relate to a room
we are not in, store the version of the room if we don't already know the room version.
"""
await self.db_pool.simple_upsert(
desc="maybe_store_room_on_invite",
desc="maybe_store_room_on_outlier_membership",
table="rooms",
keyvalues={"room_id": room_id},
values={},

View file

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@ -350,6 +350,38 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
async def get_local_current_membership_for_user_in_room(
self, user_id: str, room_id: str
) -> Tuple[Optional[str], Optional[str]]:
"""Retrieve the current local membership state and event ID for a user in a room.
Args:
user_id: The ID of the user.
room_id: The ID of the room.
Returns:
A tuple of (membership_type, event_id). Both will be None if a
room_id/user_id pair is not found.
"""
# Paranoia check.
if not self.hs.is_mine_id(user_id):
raise Exception(
"Cannot call 'get_local_current_membership_for_user_in_room' on "
"non-local user %s" % (user_id,),
)
results_dict = await self.db_pool.simple_select_one(
"local_current_membership",
{"room_id": room_id, "user_id": user_id},
("membership", "event_id"),
allow_none=True,
desc="get_local_current_membership_for_user_in_room",
)
if not results_dict:
return None, None
return results_dict.get("membership"), results_dict.get("event_id")
@cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str

View file
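A hedged usage sketch for the get_local_current_membership_for_user_in_room helper added above (note that it raises for non-local users):

from synapse.api.constants import Membership

async def is_joined(store, user_id: str, room_id: str) -> bool:
    membership, _event_id = await store.get_local_current_membership_for_user_in_room(
        user_id, room_id
    )
    return membership == Membership.JOIN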

@ -20,14 +20,14 @@
*/
-- add new index that includes method to local media
INSERT INTO background_updates (update_name, progress_json) VALUES
('local_media_repository_thumbnails_method_idx', '{}');
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5807, 'local_media_repository_thumbnails_method_idx', '{}');
-- add new index that includes method to remote media
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
(5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
-- drop old index
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
(5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');

View file

@ -28,5 +28,5 @@
-- functionality as the old one. This effectively restarts the background job
-- from the beginning, without running it twice in a row, supporting both
-- upgrade usecases.
INSERT INTO background_updates (update_name, progress_json) VALUES
('populate_stats_process_rooms_2', '{}');
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5812, 'populate_stats_process_rooms_2', '{}');

View file

@ -1,2 +1,2 @@
INSERT INTO background_updates (update_name, progress_json) VALUES
('users_have_local_media', '{}');
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5822, 'users_have_local_media', '{}');

View file

@ -13,5 +13,5 @@
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('e2e_cross_signing_keys_idx', '{}');
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5823, 'e2e_cross_signing_keys_idx', '{}');

View file

@ -0,0 +1,19 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- this index is essentially redundant. The only time it was ever used was when purging
-- rooms - and Synapse 1.24 will change that.
DROP INDEX IF EXISTS event_json_room_id;

View file

@ -317,14 +317,14 @@ mxid_localpart_allowed_characters = set(
)
def contains_invalid_mxid_characters(localpart):
def contains_invalid_mxid_characters(localpart: str) -> bool:
"""Check for characters not allowed in an mxid or groupid localpart
Args:
localpart (basestring): the localpart to be checked
localpart: the localpart to be checked
Returns:
bool: True if there are any naughty characters
True if there are any naughty characters
"""
return any(c not in mxid_localpart_allowed_characters for c in localpart)
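For illustration, the allowed set above is lowercase ASCII letters, digits, and a little punctuation, so:

assert not contains_invalid_mxid_characters("alice_42")  # all characters allowed
assert contains_invalid_mxid_characters("Alice")  # uppercase is "naughty"
assert contains_invalid_mxid_characters("bob smith")  # so is a space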

synctl
View file

@ -358,6 +358,13 @@ def main():
for worker in workers:
env = os.environ.copy()
# Skip starting a worker if it's already running
if os.path.exists(worker.pidfile) and pid_running(
int(open(worker.pidfile).read())
):
print(worker.app + " already running")
continue
if worker.cache_factor:
os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
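The pid_running() helper used here is an existing synctl utility; a sketch of its likely shape (signal 0 probes for process existence without delivering anything):

import errno
import os

def pid_running(pid: int) -> bool:
    try:
        os.kill(pid, 0)  # signal 0: existence check only
        return True
    except OSError as err:
        # EPERM means the process exists but is owned by another user
        return err.errno == errno.EPERM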

View file

@ -282,7 +282,11 @@ class AuthTestCase(unittest.TestCase):
)
)
self.store.add_access_token_to_user.assert_called_with(
USER_ID, token, "DEVICE", None
user_id=USER_ID,
token=token,
device_id="DEVICE",
valid_until_ms=None,
puppets_user_id=None,
)
def get_user(tok):

View file

@ -15,6 +15,7 @@
from synapse.app.generic_worker import GenericWorkerServer
from tests.server import make_request
from tests.unittest import HomeserverTestCase
@ -55,10 +56,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
_, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
# 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400)
@ -77,10 +76,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
_, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
# 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401)

View file

@ -20,6 +20,7 @@ from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def
from tests.server import make_request
from tests.unittest import HomeserverTestCase
@ -66,16 +67,15 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"]
site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/openid/userinfo"
_, channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
self.render(request)
self.assertEqual(channel.code, 401)
@ -115,15 +115,14 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"]
site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/openid/userinfo"
_, channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
self.render(request)
self.assertEqual(channel.code, 401)

View file

@ -51,7 +51,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.render(request)
self.assertEquals(200, channel.code)
complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity)
@ -64,7 +63,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.render(request)
self.assertEquals(200, channel.code)
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)

View file

@ -51,7 +51,6 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
)
self.render(request)
self.assertEquals(400, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
@ -99,7 +98,6 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertEqual(
@ -132,7 +130,6 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
self.render(request)
self.assertEquals(403, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")

View file

@ -40,7 +40,6 @@ class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/publicRooms"
)
self.render(request)
self.assertEquals(403, channel.code)
@override_config({"allow_public_rooms_over_federation": True})
@ -48,5 +47,4 @@ class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/publicRooms"
)
self.render(request)
self.assertEquals(200, channel.code)

View file

@ -52,7 +52,7 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self):
self.hs.clock.now = 5000
self.hs.get_clock().now = 5000
token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@ -78,7 +78,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000
self.hs.get_clock().now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = yield defer.ensureDeferred(
@ -87,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
self.hs.clock.now = 6000
self.hs.get_clock().now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)

View file

@ -412,7 +412,6 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
b"directory/room/%23test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
)
self.render(request)
self.assertEquals(403, channel.code, channel.result)
def test_allowed(self):
@ -423,7 +422,6 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
b"directory/room/%23unofficial_test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
@ -438,7 +436,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.room_list_handler = hs.get_room_list_handler()
@ -452,7 +449,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
# Room list is enabled so we should get some results
request, channel = self.make_request("GET", b"publicRooms")
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) > 0)
@ -461,7 +457,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
# Room list disabled so we should get no results
request, channel = self.make_request("GET", b"publicRooms")
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) == 0)
@ -470,5 +465,4 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
self.render(request)
self.assertEquals(403, channel.code, channel.result)

View file

@ -59,7 +59,6 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
d = self.handler.on_exchange_third_party_invite_request(
room_id=room_id,
event_dict={
"type": EventTypes.Member,
"room_id": room_id,

View file

@ -209,5 +209,4 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", path, content={}, access_token=self.access_token
)
self.render(request)
self.assertEqual(int(channel.result["code"]), 403)

View file

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from urllib.parse import parse_qs, urlparse
@ -24,12 +23,8 @@ import pymacaroons
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from synapse.handlers.oidc_handler import (
MappingException,
OidcError,
OidcHandler,
OidcMappingProvider,
)
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
from tests.unittest import HomeserverTestCase, override_config
@ -94,6 +89,14 @@ class TestMappingProviderExtra(TestMappingProvider):
return {"phone": userinfo["phone"]}
class TestMappingProviderFailures(TestMappingProvider):
async def map_user_attributes(self, userinfo, token, failures):
return {
"localpart": userinfo["username"] + (str(failures) if failures else ""),
"display_name": None,
}
def simple_async_mock(return_value=None, raises=None):
# AsyncMock is not available in python3.5, so this mimics part of its behaviour
async def cb(*args, **kwargs):
@ -124,22 +127,16 @@ async def get_json(url):
class OidcHandlerTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = "Synapse Test"
config = self.default_config()
def default_config(self):
config = super().default_config()
config["public_baseurl"] = BASE_URL
oidc_config = {}
oidc_config["enabled"] = True
oidc_config["client_id"] = CLIENT_ID
oidc_config["client_secret"] = CLIENT_SECRET
oidc_config["issuer"] = ISSUER
oidc_config["scopes"] = SCOPES
oidc_config["user_mapping_provider"] = {
"module": __name__ + ".TestMappingProvider",
oidc_config = {
"enabled": True,
"client_id": CLIENT_ID,
"client_secret": CLIENT_SECRET,
"issuer": ISSUER,
"scopes": SCOPES,
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# Update this config with what's in the default config so that
@ -147,13 +144,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
hs = self.setup_test_homeserver(
http_client=self.http_client,
proxied_http_client=self.http_client,
config=config,
)
return config
self.handler = OidcHandler(hs)
def make_homeserver(self, reactor, clock):
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = "Synapse Test"
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
self.handler = hs.get_oidc_handler()
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
sso_handler.render_error = self.render_error
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
return hs
@ -161,12 +169,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
return patch.dict(self.handler._provider_metadata, values)
def assertRenderedError(self, error, error_description=None):
args = self.handler._render_error.call_args[0]
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
if error_description is not None:
self.assertEqual(args[2], error_description)
# Reset the render_error mock
self.handler._render_error.reset_mock()
self.render_error.reset_mock()
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
@ -356,7 +364,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
self.handler._render_error = Mock()
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
self.get_success(self.handler.handle_oidc_callback(request))
@ -387,7 +394,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
"preferred_username": "bar",
}
user_id = "@foo:domain.org"
self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
@ -435,7 +441,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
self.handler._map_userinfo_to_user = simple_async_mock(
@ -469,7 +475,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()
self.render_error.assert_not_called()
# Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
@ -485,7 +491,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
self.handler._render_error = Mock(return_value=None)
request = Mock(spec=["args", "getCookie", "addCookie"])
# Missing cookie
@ -699,19 +704,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
),
MappingException,
)
self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
self.assertEqual(
str(e.value),
"Could not extract user attributes from SSO response: Mapping provider does not support de-duplicating Matrix IDs",
)
@override_config({"oidc_config": {"allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastore()
user4 = UserID.from_string("@test_user_4:test")
user = UserID.from_string("@test_user:test")
self.get_success(
store.register_user(user_id=user4.to_string(), password_hash=None)
store.register_user(user_id=user.to_string(), password_hash=None)
)
# Map a user via SSO.
userinfo = {
"sub": "test4",
"username": "test_user_4",
"sub": "test",
"username": "test_user",
}
token = {}
mxid = self.get_success(
@ -719,4 +729,137 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user_4:test")
self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid.
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
# in this case we'll just use the same username. A more realistic example
# would be subs which are email addresses, and mapping from the localpart
# of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.)
userinfo = {
"sub": "test1",
"username": "test_user",
}
token = {}
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
self.get_success(
store.register_user(user_id=user2.to_string(), password_hash=None)
)
user2_caps = UserID.from_string("@test_USER_2:test")
self.get_success(
store.register_user(user_id=user2_caps.to_string(), password_hash=None)
)
# Attempting to login without matching a name exactly is an error.
userinfo = {
"sub": "test2",
"username": "TEST_USER_2",
}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertTrue(
str(e.value).startswith(
"Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
)
)
# Logging in when matching a name exactly should work.
user2 = UserID.from_string("@TEST_USER_2:test")
self.get_success(
store.register_user(user_id=user2.to_string(), password_hash=None)
)
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@TEST_USER_2:test")
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
userinfo = {
"sub": "test2",
"username": "föö",
}
token = {}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertEqual(str(e.value), "localpart is invalid: föö")
@override_config(
{
"oidc_config": {
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderFailures"
}
}
}
)
def test_map_userinfo_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
userinfo = {
"sub": "test",
"username": "test_user",
}
token = {}
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
)
# test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test")
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
store.register_user(user_id="@tester:test", password_hash=None)
)
for i in range(1, 3):
self.get_success(
store.register_user(user_id="@tester%d:test" % i, password_hash=None)
)
# Now attempt to map to a username, this will fail since all potential usernames are taken.
userinfo = {
"sub": "tester",
"username": "tester",
}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertEqual(
str(e.value), "Unable to generate a Matrix ID from the SSO response"
)

View file

@ -0,0 +1,580 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the password_auth_provider interface"""
from typing import Any, Type, Union
from mock import Mock
from twisted.internet import defer
import synapse
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import devices
from synapse.types import JsonDict
from tests import unittest
from tests.server import FakeChannel
from tests.unittest import override_config
# (possibly experimental) login flows we expect to appear in the list after the normal
# ones
ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
# a mock instance which the dummy auth providers delegate to, so we can see what's going
# on
mock_password_provider = Mock()
class PasswordOnlyAuthProvider:
"""A password_provider which only implements `check_password`."""
@staticmethod
def parse_config(self):
pass
def __init__(self, config, account_handler):
pass
def check_password(self, *args):
return mock_password_provider.check_password(*args)
class CustomAuthProvider:
"""A password_provider which implements a custom login type."""
@staticmethod
def parse_config(self):
pass
def __init__(self, config, account_handler):
pass
def get_supported_login_types(self):
return {"test.login_type": ["test_field"]}
def check_auth(self, *args):
return mock_password_provider.check_auth(*args)
class PasswordCustomAuthProvider:
"""A password_provider which implements password login via `check_auth`, as well
as a custom type."""
@staticmethod
def parse_config(self):
pass
def __init__(self, config, account_handler):
pass
def get_supported_login_types(self):
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
def check_auth(self, *args):
return mock_password_provider.check_auth(*args)
def providers_config(*providers: Type[Any]) -> dict:
"""Returns a config dict that will enable the given password auth providers"""
return {
"password_providers": [
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
for provider in providers
]
}
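The dummy providers above return Deferreds from a module-level mock; as the tests note, check_password must return an awaitable, so a real provider can equally be a coroutine. A minimal hedged sketch of that form:

class AsyncPasswordProvider:
    """A hypothetical provider showing the coroutine form of the interface."""

    @staticmethod
    def parse_config(config):
        return config

    def __init__(self, config, account_handler):
        self._config = config

    async def check_password(self, user_id: str, password: str) -> bool:
        # Illustrative only: accept a single fixed password.
        return password == "hunter2"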
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
devices.register_servlets,
]
def setUp(self):
# we use a global mock provider, so make sure we are starting with a clean slate
mock_password_provider.reset_mock()
super().setUp()
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_password_only_auth_provider_login(self):
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
# login with mxid should work too
channel = self._send_password_login("@u:bz", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:bz", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
mock_password_provider.reset_mock()
# try a weird username / pass. Honestly it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with(
"@ USER🙂NAME :test", " pASS😢word "
)
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_password_only_auth_provider_ui_auth(self):
"""UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work
module_api = self.hs.get_module_api()
self.get_success(module_api.register_user("u"))
# log in twice, to get two devices
mock_password_provider.check_password.return_value = defer.succeed(True)
tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock()
# have the auth provider deny the request to start with
mock_password_provider.check_password.return_value = defer.succeed(False)
# make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2")
mock_password_provider.check_password.assert_not_called()
# Make another request providing the UI auth flow.
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 401) # XXX why not a 403?
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
# Finally, check the request goes through when we allow it
mock_password_provider.check_password.return_value = defer.succeed(True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_local_user_fallback_login(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_local_user_fallback_ui_auth(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
# have the auth provider deny the request
mock_password_provider.check_password.return_value = defer.succeed(False)
# log in twice, to get two devices
tok1 = self.login("localuser", "localpass")
self.login("localuser", "localpass", device_id="dev2")
mock_password_provider.check_password.reset_mock()
# first delete should give a 401
session = self._start_delete_device_session(tok1, "dev2")
mock_password_provider.check_password.assert_not_called()
# Wrong password
channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
self.assertEqual(channel.code, 401) # XXX why not a 403?
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "xxx"
)
mock_password_provider.reset_mock()
# Right password
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "localpass"
)
@override_config(
{
**providers_config(PasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_login(self):
"""localdb_enabled can block login with the local password
"""
self.register_user("localuser", "localpass")
# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "localpass"
)
@override_config(
{
**providers_config(PasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_ui_auth(self):
"""localdb_enabled can block ui auth with the local password
"""
self.register_user("localuser", "localpass")
# allow login via the auth provider
mock_password_provider.check_password.return_value = defer.succeed(True)
# log in twice, to get two devices
tok1 = self.login("localuser", "p")
self.login("localuser", "p", device_id="dev2")
mock_password_provider.check_password.reset_mock()
# first delete should give a 401
channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401)
# m.login.password UIA is permitted because the auth provider allows it,
# even though the localdb does not.
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
session = channel.json_body["session"]
mock_password_provider.check_password.assert_not_called()
# now try deleting with the local password
mock_password_provider.check_password.return_value = defer.succeed(False)
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
self.assertEqual(channel.code, 401) # XXX why not a 403?
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "localpass"
)
@override_config(
{
**providers_config(PasswordOnlyAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_auth_disabled(self):
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_password.assert_not_called()
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_login(self):
# login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup.
# (password must come first, because reasons)
flows = self._get_login_flows()
self.assertEqual(
flows,
[{"type": "m.login.password"}, {"type": "test.login_type"}]
+ ADDITIONAL_LOGIN_FLOWS,
)
# login with missing param should be rejected
channel = self._send_login("test.login_type", "u")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
)
mock_password_provider.reset_mock()
# try a weird username. Again, it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed(
"@ MALFORMED! :bz"
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
)
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_ui_auth(self):
# register the user and log in twice, to get two devices
self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass")
self.login("localuser", "localpass", device_id="dev2")
# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401)
# Ensure that flows are what is expected.
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
session = channel.json_body["session"]
# missing param
body = {
"auth": {
"type": "test.login_type",
"identifier": {"type": "m.id.user", "user": "localuser"},
"session": session,
},
}
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 400)
# there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
# use it...
self.assertIn("Missing parameters", channel.json_body["error"])
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "foo"}
)
mock_password_provider.reset_mock()
# and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed(
"@localuser:test"
)
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "foo"}
)
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_callback(self):
callback = Mock(return_value=defer.succeed(None))
mock_password_provider.check_auth.return_value = defer.succeed(
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
)
# check the args to the callback
callback.assert_called_once()
call_args, call_kwargs = callback.call_args
# should be one positional arg
self.assertEqual(len(call_args), 1)
self.assertEqual(call_args[0]["user_id"], "@user:bz")
for p in ["user_id", "access_token", "device_id", "home_server"]:
self.assertIn(p, call_args[0])
@override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
)
def test_custom_auth_password_disabled(self):
"""Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass")
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
{
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login(self):
"""log in with a custom auth provider which implements password, but password
login is disabled"""
self.register_user("localuser", "localpass")
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
{
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_ui_auth(self):
"""UI Auth with a custom auth provider which implements password, but password
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed(
"@localuser:test"
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, 200, channel.result)
tok1 = channel.json_body["access_token"]
channel = self._send_login(
"test.login_type", "localuser", test_field="", device_id="dev2"
)
self.assertEqual(channel.code, 200, channel.result)
# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401)
# Ensure that flows are what is expected. In particular, "password" should *not*
# be present.
self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
session = channel.json_body["session"]
mock_password_provider.reset_mock()
# check that auth with password is rejected
body = {
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "localuser"},
"password": "localpass",
"session": session,
},
}
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 400)
self.assertEqual(
"Password login has been disabled.", channel.json_body["error"]
)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.reset_mock()
# successful auth
body["auth"]["type"] = "test.login_type"
body["auth"]["test_field"] = "x"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "x"}
)
@override_config(
{
**providers_config(CustomAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_custom_auth_no_local_user_fallback(self):
"""Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass")
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
# password login shouldn't work and should be rejected with a 400
# ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
def _get_login_flows(self) -> JsonDict:
_, channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["flows"]
def _send_password_login(self, user: str, password: str) -> FakeChannel:
return self._send_login(type="m.login.password", user=user, password=password)
def _send_login(self, type, user, **params) -> FakeChannel:
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
_, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel
def _start_delete_device_session(self, access_token, device_id) -> str:
"""Make an initial delete device request, and return the UI Auth session ID"""
channel = self._delete_device(access_token, device_id)
self.assertEqual(channel.code, 401)
# Ensure that flows are what is expected.
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
return channel.json_body["session"]
def _authed_delete_device(
self,
access_token: str,
device_id: str,
session: str,
user_id: str,
password: str,
) -> FakeChannel:
"""Make a delete device request, authenticating with the given uid/password"""
return self._delete_device(
access_token,
device_id,
{
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": user_id},
"password": password,
"session": session,
},
},
)
def _delete_device(
self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
) -> FakeChannel:
"""Delete an individual device."""
_, channel = self.make_request(
"DELETE", "devices/" + device, body, access_token=access_token
)
return channel

tests/handlers/test_saml.py
View file

@ -0,0 +1,168 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import attr
from synapse.handlers.sso import MappingException
from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests.
BASE_URL = "https://synapse/"
@attr.s
class FakeAuthnResponse:
ava = attr.ib(type=dict)
class TestMappingProvider:
def __init__(self, config, module):
pass
@staticmethod
def parse_config(config):
return
@staticmethod
def get_saml_attributes(config):
return {"uid"}, {"displayName"}
def get_remote_user_id(self, saml_response, client_redirect_url):
return saml_response.ava["uid"]
def saml_response_to_user_attributes(
self, saml_response, failures, client_redirect_url
):
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
return {"mxid_localpart": localpart, "displayname": None}
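To make the retry contract concrete: on an MXID collision the SSO handler calls the provider again with an incremented failures count, and this provider derives a fresh candidate localpart each time. For example:

provider = TestMappingProvider(config=None, module=None)
response = FakeAuthnResponse({"uid": "u1", "username": "test_user"})

provider.saml_response_to_user_attributes(response, 0, "")
# -> {"mxid_localpart": "test_user", "displayname": None}
provider.saml_response_to_user_attributes(response, 1, "")
# -> {"mxid_localpart": "test_user1", "displayname": None}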
class SamlHandlerTestCase(HomeserverTestCase):
def default_config(self):
config = super().default_config()
config["public_baseurl"] = BASE_URL
saml_config = {
"sp_config": {"metadata": {}},
# Disable grandfathering.
"grandfathered_mxid_source_attribute": None,
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# Update this config with what's in the default config so that
# override_config works as expected.
saml_config.update(config.get("saml2_config", {}))
config["saml2_config"] = saml_config
return config
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
self.handler = hs.get_saml_handler()
# Reduce the number of attempts when generating MXIDs.
sso_handler = hs.get_sso_handler()
sso_handler._MAP_USERNAME_RETRIES = 3
return hs
def test_map_saml_response_to_user(self):
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
# The redirect_url doesn't matter with the default user mapping provider.
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self):
"""Existing users can log in with SAML account."""
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
# Map a user via SSO.
saml_response = FakeAuthnResponse(
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
)
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid.
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
def test_map_saml_response_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
redirect_url = ""
e = self.get_failure(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertEqual(str(e.value), "localpart is invalid: föö")
def test_map_saml_response_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
)
# test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test")
# Register all of the potential mxids for a particular SAML username.
self.get_success(
store.register_user(user_id="@tester:test", password_hash=None)
)
for i in range(1, 3):
self.get_success(
store.register_user(user_id="@tester%d:test" % i, password_hash=None)
)
# Now attempt to map to a username, this will fail since all potential usernames are taken.
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
e = self.get_failure(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertEqual(
str(e.value), "Unable to generate a Matrix ID from the SSO response"
)

View file

@ -16,7 +16,7 @@
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig
from synapse.types import UserID
from synapse.types import UserID, create_requester
import tests.unittest
import tests.utils
@ -38,6 +38,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
self.auth_blocking._limit_usage_by_mau = True
@ -45,21 +46,26 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1))
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
self.get_success(
self.sync_handler.wait_for_sync_for_user(requester, sync_config)
)
# Test that global lock works
self.auth_blocking._hs_disabled = True
e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
self.sync_handler.wait_for_sync_for_user(requester, sync_config),
ResourceLimitError,
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
requester = create_requester(user_id2)
e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
self.sync_handler.wait_for_sync_for_user(requester, sync_config),
ResourceLimitError,
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

View file

@ -228,7 +228,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
),
federation_auth_origin=b"farm",
)
self.render(request)
self.assertEqual(channel.code, 200)
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])

View file

@ -537,7 +537,6 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) > 0)
@ -546,6 +545,5 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)

View file

@ -17,6 +17,7 @@
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json
from tests.server import FakeSite, make_request
from tests.unittest import HomeserverTestCase
@ -43,20 +44,18 @@ class AdditionalResourceTests(HomeserverTestCase):
def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/")
self.render(request)
request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/")
self.render(request)
request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})

View file

@ -94,12 +94,13 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertFalse(hasattr(event, "state_key"))
self.assertDictEqual(event.content, content)
expected_requester = create_requester(
user_id, authenticated_entity=self.hs.hostname
)
# Check that the event was sent
self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
create_requester(user_id),
event_dict,
ratelimit=False,
ignore_shadow_ban=True,
expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
)
# Create and send a state event
@ -128,7 +129,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the event was sent
self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
create_requester(user_id),
expected_requester,
{
"type": "m.room.power_levels",
"content": content,

View file

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet.defer import Deferred
@ -20,8 +19,9 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import receipts
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
class HTTPPusherTests(HomeserverTestCase):
@@ -29,6 +29,7 @@ class HTTPPusherTests(HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
+        receipts.register_servlets,
]
user_id = True
hijack_auth = False
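The `receipts` servlet is registered above because the new unread-count tests below POST a read receipt (`/rooms/<room>/receipt/m.read/<event>`) to move the notified user's read position.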
@@ -499,3 +500,161 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
def test_push_unread_count_group_by_room(self):
"""
        The HTTP pusher will group the unread count by the number of unread rooms.
"""
# Carry out common push count tests and setup
self._test_push_unread_count()
# Carry out our option-value specific test
#
# This push should still only contain an unread count of 1 (for 1 unread room)
self.assertEqual(
self.push_attempts[5][2]["notification"]["counts"]["unread"], 1
)
@override_config({"push": {"group_unread_count_by_room": False}})
def test_push_unread_count_message_count(self):
"""
The HTTP pusher will send the total unread message count.
"""
# Carry out common push count tests and setup
self._test_push_unread_count()
# Carry out our option-value specific test
#
# We're counting every unread message, so there should now be 4 since the
# last read receipt
self.assertEqual(
self.push_attempts[5][2]["notification"]["counts"]["unread"], 4
)
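    # Note: the two tests above differ only in the config override.
    # push.group_unread_count_by_room defaults to True (the badge counts
    # unread rooms); flipping it to False, as in the decorator above, makes
    # the badge count unread messages instead. The override uses the same
    # nested shape as the server config, i.e.
    # {"push": {"group_unread_count_by_room": False}}.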
def _test_push_unread_count(self):
"""
Tests that the correct unread count appears in sent push notifications
Note that:
* Sending messages will cause push notifications to go out to relevant users
* Sending a read receipt will cause a "badge update" notification to go out to
the user that sent the receipt
"""
# Register the user who gets notified
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
# Register the user who sends the message
other_user_id = self.register_user("other_user", "pass")
other_access_token = self.login("other_user", "pass")
# Create a room (as other_user)
room_id = self.helper.create_room_as(other_user_id, tok=other_access_token)
# The user to get notified joins
self.helper.join(room=room_id, user=user_id, tok=access_token)
# Register the pusher
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
app_id="m.http",
app_display_name="HTTP Push Notifications",
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "example.com"},
)
)
# Send a message
response = self.helper.send(
room_id, body="Hello there!", tok=other_access_token
)
# To get an unread count, the user who is getting notified has to have a read
# position in the room. We'll set the read position to this event in a moment
first_message_event_id = response["event_id"]
# Advance time a bit (so the pusher will register something has happened) and
# make the push succeed
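        # (Each self.push_attempts entry is assumed to be a
        # (request-deferred, url, body) tuple captured by the test harness's
        # stubbed HTTP client; firing the deferred completes the mocked POST.)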
self.push_attempts[0][0].callback({})
self.pump()
# Check our push made it
self.assertEqual(len(self.push_attempts), 1)
self.assertEqual(self.push_attempts[0][1], "example.com")
# Check that the unread count for the room is 0
#
# The unread count is zero as the user has no read receipt in the room yet
self.assertEqual(
self.push_attempts[0][2]["notification"]["counts"]["unread"], 0
)
# Now set the user's read receipt position to the first event
#
# This will actually trigger a new notification to be sent out so that
# even if the user does not receive another message, their unread
# count goes down
request, channel = self.make_request(
"POST",
"/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
{},
access_token=access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Advance time and make the push succeed
self.push_attempts[1][0].callback({})
self.pump()
# Unread count is still zero as we've read the only message in the room
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(
self.push_attempts[1][2]["notification"]["counts"]["unread"], 0
)
# Send another message
self.helper.send(
room_id, body="How's the weather today?", tok=other_access_token
)
# Advance time and make the push succeed
self.push_attempts[2][0].callback({})
self.pump()
# This push should contain an unread count of 1 as there's now been one
# message since our last read receipt
self.assertEqual(len(self.push_attempts), 3)
self.assertEqual(
self.push_attempts[2][2]["notification"]["counts"]["unread"], 1
)
# Since we're grouping by room, sending more messages shouldn't increase the
# unread count, as they're all being sent in the same room
self.helper.send(room_id, body="Hello?", tok=other_access_token)
# Advance time and make the push succeed
self.pump()
self.push_attempts[3][0].callback({})
self.helper.send(room_id, body="Hello??", tok=other_access_token)
# Advance time and make the push succeed
self.pump()
self.push_attempts[4][0].callback({})
self.helper.send(room_id, body="HELLO???", tok=other_access_token)
# Advance time and make the push succeed
self.pump()
self.push_attempts[5][0].callback({})
self.assertEqual(len(self.push_attempts), 6)
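For orientation, the assertions in these tests index into the push-gateway notification body. A trimmed sketch of the shape being inspected, keeping only the fields the tests touch (real payloads carry more):

    # Hypothetical minimal body, as captured in push_attempts[n][2].
    body = {
        "notification": {
            "counts": {"unread": 1},  # unread rooms or messages, per config
            "prio": "low",
        }
    }
    unread = body["notification"]["counts"]["unread"]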

Some files were not shown because too many files have changed in this diff.