Mirror of https://git.anonymousland.org/anonymousland/synapse-product.git (synced 2025-08-08 11:42:12 -04:00)

Merge remote-tracking branch 'upstream/release-v1.35'

Commit 4740b83c39: 98 changed files with 6440 additions and 2105 deletions
@@ -3,7 +3,7 @@
 # CI's Docker setup at the point where this file is considered.
 server_name: "localhost:8800"
 
-signing_key_path: "/src/.buildkite/test.signing.key"
+signing_key_path: ".buildkite/test.signing.key"
 
 report_stats: false
 
@@ -16,6 +16,4 @@ database:
 database: synapse
 
 # Suppress the key server warning.
-trusted_key_servers:
-  - server_name: "matrix.org"
-suppress_key_server_warning: true
+trusted_key_servers: []
 
@@ -33,6 +33,10 @@ scripts-dev/update_database --database-config .buildkite/sqlite-config.yaml
 echo "+++ Run synapse_port_db against test database"
 coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml
 
+# We should be able to run twice against the same database.
+echo "+++ Run synapse_port_db a second time"
+coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml
+
 #####
 
 # Now do the same again, on an empty database.
 
@@ -3,7 +3,7 @@
 # schema and run background updates on it.
 server_name: "localhost:8800"
 
-signing_key_path: "/src/.buildkite/test.signing.key"
+signing_key_path: ".buildkite/test.signing.key"
 
 report_stats: false
 
@@ -13,6 +13,4 @@ database:
 database: ".buildkite/test_db.db"
 
 # Suppress the key server warning.
-trusted_key_servers:
-  - server_name: "matrix.org"
-suppress_key_server_warning: true
+trusted_key_servers: []
 
64 CHANGES.md
@@ -1,3 +1,67 @@
+Synapse 1.35.0rc1 (2021-05-25)
+==============================
+
+Features
+--------
+
+- Add experimental support to allow a user who could join a restricted room to view it in the spaces summary. ([\#9922](https://github.com/matrix-org/synapse/issues/9922), [\#10007](https://github.com/matrix-org/synapse/issues/10007), [\#10038](https://github.com/matrix-org/synapse/issues/10038))
+- Reduce memory usage when joining very large rooms over federation. ([\#9958](https://github.com/matrix-org/synapse/issues/9958))
+- Add a configuration option which allows enabling opentracing by user id. ([\#9978](https://github.com/matrix-org/synapse/issues/9978))
+- Enable experimental support for [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946) (spaces summary API) and [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) (restricted join rules) by default. ([\#10011](https://github.com/matrix-org/synapse/issues/10011))
+
+
+Bugfixes
+--------
+
+- Fix a bug introduced in v1.26.0 which meant that `synapse_port_db` would not correctly initialise some postgres sequences, requiring manual updates afterwards. ([\#9991](https://github.com/matrix-org/synapse/issues/9991))
+- Fix `synctl`'s `--no-daemonize` parameter to work correctly with worker processes. ([\#9995](https://github.com/matrix-org/synapse/issues/9995))
+- Fix a validation bug introduced in v1.34.0 in the ordering of spaces in the space summary API. ([\#10002](https://github.com/matrix-org/synapse/issues/10002))
+- Fixed deletion of new presence stream states from database. ([\#10014](https://github.com/matrix-org/synapse/issues/10014), [\#10033](https://github.com/matrix-org/synapse/issues/10033))
+- Fixed a bug with very high resolution image uploads throwing internal server errors. ([\#10029](https://github.com/matrix-org/synapse/issues/10029))
+
+
+Updates to the Docker image
+---------------------------
+
+- Fix bug introduced in Synapse 1.33.0 which caused a `Permission denied: '/homeserver.log'` error when starting Synapse with the generated log configuration. Contributed by Sergio Miguéns Iglesias. ([\#10045](https://github.com/matrix-org/synapse/issues/10045))
+
+
+Improved Documentation
+----------------------
+
+- Add hardened systemd files as proposed in [#9760](https://github.com/matrix-org/synapse/issues/9760) and added them to `contrib/`. Change the docs to reflect the presence of these files. ([\#9803](https://github.com/matrix-org/synapse/issues/9803))
+- Clarify documentation around SSO mapping providers generating unique IDs and localparts. ([\#9980](https://github.com/matrix-org/synapse/issues/9980))
+- Updates to the PostgreSQL documentation (`postgres.md`). ([\#9988](https://github.com/matrix-org/synapse/issues/9988), [\#9989](https://github.com/matrix-org/synapse/issues/9989))
+- Fix broken link in user directory documentation. Contributed by @junquera. ([\#10016](https://github.com/matrix-org/synapse/issues/10016))
+- Add missing room state entry to the table of contents of room admin API. ([\#10043](https://github.com/matrix-org/synapse/issues/10043))
+
+
+Deprecations and Removals
+-------------------------
+
+- Removed support for the deprecated `tls_fingerprints` configuration setting. Contributed by Jerin J Titus. ([\#9280](https://github.com/matrix-org/synapse/issues/9280))
+
+
+Internal Changes
+----------------
+
+- Allow sending full presence to users via workers other than the one that called `ModuleApi.send_local_online_presence_to`. ([\#9823](https://github.com/matrix-org/synapse/issues/9823))
+- Update comments in the space summary handler. ([\#9974](https://github.com/matrix-org/synapse/issues/9974))
+- Minor enhancements to the `@cachedList` descriptor. ([\#9975](https://github.com/matrix-org/synapse/issues/9975))
+- Split multipart email sending into a dedicated handler. ([\#9977](https://github.com/matrix-org/synapse/issues/9977))
+- Run `black` on files in the `scripts` directory. ([\#9981](https://github.com/matrix-org/synapse/issues/9981))
+- Add missing type hints to `synapse.util` module. ([\#9982](https://github.com/matrix-org/synapse/issues/9982))
+- Simplify a few helper functions. ([\#9984](https://github.com/matrix-org/synapse/issues/9984), [\#9985](https://github.com/matrix-org/synapse/issues/9985), [\#9986](https://github.com/matrix-org/synapse/issues/9986))
+- Remove unnecessary property from SQLBaseStore. ([\#9987](https://github.com/matrix-org/synapse/issues/9987))
+- Remove `keylen` param on `LruCache`. ([\#9993](https://github.com/matrix-org/synapse/issues/9993))
+- Update the Grafana dashboard in `contrib/`. ([\#10001](https://github.com/matrix-org/synapse/issues/10001))
+- Add a batching queue implementation. ([\#10017](https://github.com/matrix-org/synapse/issues/10017))
+- Reduce memory usage when verifying signatures on large numbers of events at once. ([\#10018](https://github.com/matrix-org/synapse/issues/10018))
+- Properly invalidate caches for destination retry timings every (instead of expiring entries every 5 minutes). ([\#10036](https://github.com/matrix-org/synapse/issues/10036))
+- Fix running complement tests with Synapse workers. ([\#10039](https://github.com/matrix-org/synapse/issues/10039))
+- Fix typo in `get_state_ids_for_event` docstring where the return type was incorrect. ([\#10050](https://github.com/matrix-org/synapse/issues/10050))
+
+
 Synapse 1.34.0 (2021-05-17)
 ===========================
 
File diff suppressed because it is too large
71 contrib/systemd/override-hardened.conf (new file)
@@ -0,0 +1,71 @@
+[Service]
+# The following directives give the synapse service R/W access to:
+# - /run/matrix-synapse
+# - /var/lib/matrix-synapse
+# - /var/log/matrix-synapse
+
+RuntimeDirectory=matrix-synapse
+StateDirectory=matrix-synapse
+LogsDirectory=matrix-synapse
+
+######################
+## Security Sandbox ##
+######################
+
+# Make sure that the service has its own unshared tmpfs at /tmp and that it
+# cannot see or change any real devices
+PrivateTmp=true
+PrivateDevices=true
+
+# We give no capabilities to a service by default
+CapabilityBoundingSet=
+AmbientCapabilities=
+
+# Protect the following from modification:
+# - The entire filesystem
+# - sysctl settings and loaded kernel modules
+# - No modifications allowed to Control Groups
+# - Hostname
+# - System Clock
+ProtectSystem=strict
+ProtectKernelTunables=true
+ProtectKernelModules=true
+ProtectControlGroups=true
+ProtectClock=true
+ProtectHostname=true
+
+# Prevent access to the following:
+# - /home directory
+# - Kernel logs
+ProtectHome=tmpfs
+ProtectKernelLogs=true
+
+# Make sure that the process can only see PIDs and process details of itself,
+# and the second option disables seeing details of things like system load and
+# I/O etc
+ProtectProc=invisible
+ProcSubset=pid
+
+# While not needed, we set these options explicitly
+# - This process has been given access to the host network
+# - It can also communicate with any IP Address
+PrivateNetwork=false
+RestrictAddressFamilies=AF_INET AF_INET6 AF_UNIX
+IPAddressAllow=any
+
+# Restrict system calls to a sane bunch
+SystemCallArchitectures=native
+SystemCallFilter=@system-service
+SystemCallFilter=~@privileged @resources @obsolete
+
+# Misc restrictions
+# - Since the process is a python process it needs to be able to write and
+#   execute memory regions, so we set MemoryDenyWriteExecute to false
+RestrictSUIDSGID=true
+RemoveIPC=true
+NoNewPrivileges=true
+RestrictRealtime=true
+RestrictNamespaces=true
+LockPersonality=true
+PrivateUsers=true
+MemoryDenyWriteExecute=false
@@ -9,10 +9,11 @@ formatters:
 {% endif %}
 
 handlers:
+{% if LOG_FILE_PATH %}
   file:
     class: logging.handlers.TimedRotatingFileHandler
    formatter: precise
-    filename: {{ LOG_FILE_PATH or "homeserver.log" }}
+    filename: {{ LOG_FILE_PATH }}
     when: "midnight"
     backupCount: 6 # Does not include the current log file.
     encoding: utf8
@@ -29,6 +30,7 @@ handlers:
     # be written to disk.
     capacity: 10
     flushLevel: 30 # Flush for WARNING logs as well
+{% endif %}
 
   console:
     class: logging.StreamHandler
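As an aside on the buffer handler shown above, `flushLevel: 30` corresponds to `logging.WARNING`. A minimal, standalone sketch (not Synapse's actual logging setup; handler wiring and messages here are illustrative) of how Python's `MemoryHandler` behaves with that flush level:

```python
import logging
import logging.handlers

# Target handler that actually writes records out once the buffer is flushed.
console = logging.StreamHandler()
console.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))

# Buffer up to 10 records; flush early as soon as a record at WARNING (30) or
# above arrives, mirroring `capacity: 10` and `flushLevel: 30` in the config.
buffered = logging.handlers.MemoryHandler(
    capacity=10,
    flushLevel=logging.WARNING,
    target=console,
)

logger = logging.getLogger("demo")
logger.setLevel(logging.DEBUG)
logger.addHandler(buffered)

logger.debug("held in the buffer until capacity is reached or a warning arrives")
logger.warning("this record triggers an immediate flush of everything buffered")
```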
@@ -184,18 +184,18 @@ stderr_logfile_maxbytes=0
 """
 
 NGINX_LOCATION_CONFIG_BLOCK = """
-location ~* {endpoint} {
+location ~* {endpoint} {{
     proxy_pass {upstream};
     proxy_set_header X-Forwarded-For $remote_addr;
     proxy_set_header X-Forwarded-Proto $scheme;
     proxy_set_header Host $host;
-}
+}}
 """
 
 NGINX_UPSTREAM_CONFIG_BLOCK = """
-upstream {upstream_worker_type} {
+upstream {upstream_worker_type} {{
     {body}
-}
+}}
 """
 
 
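The `{` to `{{` changes above are the usual escaping rule for Python `str.format` templates: literal braces in the output have to be doubled so they are not parsed as replacement fields. A small standalone illustration (the template name, endpoint regex and upstream address below are made up for the example, not taken from the generator script):

```python
# A cut-down template in the same style: {endpoint} and {upstream} are
# placeholders, while the doubled braces become literal braces in the output.
LOCATION_TEMPLATE = """
location ~* {endpoint} {{
    proxy_pass {upstream};
}}
"""

rendered = LOCATION_TEMPLATE.format(
    endpoint="/_matrix/client/r0/sync$",  # hypothetical endpoint regex
    upstream="http://localhost:8083",     # hypothetical worker address
)
print(rendered)
# With single braces, str.format() would try to parse the nginx body as
# replacement fields and raise KeyError/ValueError instead of rendering it.
```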
@@ -4,6 +4,7 @@
 * [Usage](#usage)
 - [Room Details API](#room-details-api)
 - [Room Members API](#room-members-api)
+- [Room State API](#room-state-api)
 - [Delete Room API](#delete-room-api)
 * [Parameters](#parameters-1)
 * [Response](#response)
@@ -42,17 +42,17 @@ To receive OpenTracing spans, start up a Jaeger server. This can be done
 using docker like so:
 
 ```sh
-docker run -d --name jaeger
+docker run -d --name jaeger \
   -p 6831:6831/udp \
   -p 6832:6832/udp \
   -p 5778:5778 \
   -p 16686:16686 \
   -p 14268:14268 \
-  jaegertracing/all-in-one:1.13
+  jaegertracing/all-in-one:1
 ```
 
 Latest documentation is probably at
-<https://www.jaegertracing.io/docs/1.13/getting-started/>
+https://www.jaegertracing.io/docs/latest/getting-started.
 
 ## Enable OpenTracing in Synapse
 
@@ -62,7 +62,7 @@ as shown in the [sample config](./sample_config.yaml). For example:
 
 ```yaml
 opentracing:
-  tracer_enabled: true
+  enabled: true
   homeserver_whitelist:
     - "mytrustedhomeserver.org"
     - "*.myotherhomeservers.com"
@@ -90,4 +90,4 @@ to two problems, namely:
 ## Configuring Jaeger
 
 Sampling strategies can be set as in this document:
-<https://www.jaegertracing.io/docs/1.13/sampling/>
+<https://www.jaegertracing.io/docs/latest/sampling/>.
184 docs/postgres.md
@@ -1,6 +1,6 @@
 # Using Postgres
 
-Postgres version 9.5 or later is known to work.
+Synapse supports PostgreSQL versions 9.6 or later.
 
 ## Install postgres client libraries
 
@@ -33,28 +33,15 @@ Assuming your PostgreSQL database user is called `postgres`, first authenticate
 # Or, if your system uses sudo to get administrative rights
 sudo -u postgres bash
 
-Then, create a user ``synapse_user`` with:
+Then, create a postgres user and a database with:
 
+# this will prompt for a password for the new user
 createuser --pwprompt synapse_user
 
-Before you can authenticate with the `synapse_user`, you must create a
-database that it can access. To create a database, first connect to the
-database with your database user:
-
-su - postgres # Or: sudo -u postgres bash
-psql
-
-and then run:
-
-CREATE DATABASE synapse
-ENCODING 'UTF8'
-LC_COLLATE='C'
-LC_CTYPE='C'
-template=template0
-OWNER synapse_user;
-
-This would create an appropriate database named `synapse` owned by the
-`synapse_user` user (which must already have been created as above).
+createdb --encoding=UTF8 --locale=C --template=template0 --owner=synapse_user synapse
+
+The above will create a user called `synapse_user`, and a database called
+`synapse`.
 
 Note that the PostgreSQL database *must* have the correct encoding set
 (as shown above), otherwise it will not be able to store UTF8 strings.
@@ -63,79 +50,6 @@ You may need to enable password authentication so `synapse_user` can
 connect to the database. See
 <https://www.postgresql.org/docs/current/auth-pg-hba-conf.html>.
 
-If you get an error along the lines of `FATAL: Ident authentication failed for
-user "synapse_user"`, you may need to use an authentication method other than
-`ident`:
-
-* If the `synapse_user` user has a password, add the password to the `database:`
-  section of `homeserver.yaml`. Then add the following to `pg_hba.conf`:
-
-  ```
-  host synapse synapse_user ::1/128 md5 # or `scram-sha-256` instead of `md5` if you use that
-  ```
-
-* If the `synapse_user` user does not have a password, then a password doesn't
-  have to be added to `homeserver.yaml`. But the following does need to be added
-  to `pg_hba.conf`:
-
-  ```
-  host synapse synapse_user ::1/128 trust
-  ```
-
-Note that line order matters in `pg_hba.conf`, so make sure that if you do add a
-new line, it is inserted before:
-
-```
-host all all ::1/128 ident
-```
-
-### Fixing incorrect `COLLATE` or `CTYPE`
-
-Synapse will refuse to set up a new database if it has the wrong values of
-`COLLATE` and `CTYPE` set, and will log warnings on existing databases. Using
-different locales can cause issues if the locale library is updated from
-underneath the database, or if a different version of the locale is used on any
-replicas.
-
-The safest way to fix the issue is to take a dump and recreate the database with
-the correct `COLLATE` and `CTYPE` parameters (as shown above). It is also possible to change the
-parameters on a live database and run a `REINDEX` on the entire database,
-however extreme care must be taken to avoid database corruption.
-
-Note that the above may fail with an error about duplicate rows if corruption
-has already occurred, and such duplicate rows will need to be manually removed.
-
-
-## Fixing inconsistent sequences error
-
-Synapse uses Postgres sequences to generate IDs for various tables. A sequence
-and associated table can get out of sync if, for example, Synapse has been
-downgraded and then upgraded again.
-
-To fix the issue shut down Synapse (including any and all workers) and run the
-SQL command included in the error message. Once done Synapse should start
-successfully.
-
-
-## Tuning Postgres
-
-The default settings should be fine for most deployments. For larger
-scale deployments tuning some of the settings is recommended, details of
-which can be found at
-<https://wiki.postgresql.org/wiki/Tuning_Your_PostgreSQL_Server>.
-
-In particular, we've found tuning the following values helpful for
-performance:
-
-- `shared_buffers`
-- `effective_cache_size`
-- `work_mem`
-- `maintenance_work_mem`
-- `autovacuum_work_mem`
-
-Note that the appropriate values for those fields depend on the amount
-of free memory the database host has available.
-
 ## Synapse config
 
 When you are ready to start using PostgreSQL, edit the `database`
@@ -165,6 +79,10 @@ may block for an extended period while it waits for a response from the
 database server. Example values might be:
 
 ```yaml
+database:
+  args:
+    # ... as above
+
 # seconds of inactivity after which TCP should send a keepalive message to the server
 keepalives_idle: 10
 
@@ -177,6 +95,26 @@ keepalives_interval: 10
 keepalives_count: 3
 ```
 
+## Tuning Postgres
+
+The default settings should be fine for most deployments. For larger
+scale deployments tuning some of the settings is recommended, details of
+which can be found at
+<https://wiki.postgresql.org/wiki/Tuning_Your_PostgreSQL_Server>.
+
+In particular, we've found tuning the following values helpful for
+performance:
+
+- `shared_buffers`
+- `effective_cache_size`
+- `work_mem`
+- `maintenance_work_mem`
+- `autovacuum_work_mem`
+
+Note that the appropriate values for those fields depend on the amount
+of free memory the database host has available.
+
+
 ## Porting from SQLite
 
 ### Overview
|
@ -185,9 +123,8 @@ The script `synapse_port_db` allows porting an existing synapse server
|
||||||
backed by SQLite to using PostgreSQL. This is done in as a two phase
|
backed by SQLite to using PostgreSQL. This is done in as a two phase
|
||||||
process:
|
process:
|
||||||
|
|
||||||
1. Copy the existing SQLite database to a separate location (while the
|
1. Copy the existing SQLite database to a separate location and run
|
||||||
server is down) and running the port script against that offline
|
the port script against that offline database.
|
||||||
database.
|
|
||||||
2. Shut down the server. Rerun the port script to port any data that
|
2. Shut down the server. Rerun the port script to port any data that
|
||||||
has come in since taking the first snapshot. Restart server against
|
has come in since taking the first snapshot. Restart server against
|
||||||
the PostgreSQL database.
|
the PostgreSQL database.
|
||||||
|
@@ -245,3 +182,60 @@ PostgreSQL database configuration file `homeserver-postgres.yaml`:
 ./synctl start
 
 Synapse should now be running against PostgreSQL.
+
+
+## Troubleshooting
+
+### Alternative auth methods
+
+If you get an error along the lines of `FATAL: Ident authentication failed for
+user "synapse_user"`, you may need to use an authentication method other than
+`ident`:
+
+* If the `synapse_user` user has a password, add the password to the `database:`
+  section of `homeserver.yaml`. Then add the following to `pg_hba.conf`:
+
+  ```
+  host synapse synapse_user ::1/128 md5 # or `scram-sha-256` instead of `md5` if you use that
+  ```
+
+* If the `synapse_user` user does not have a password, then a password doesn't
+  have to be added to `homeserver.yaml`. But the following does need to be added
+  to `pg_hba.conf`:
+
+  ```
+  host synapse synapse_user ::1/128 trust
+  ```
+
+Note that line order matters in `pg_hba.conf`, so make sure that if you do add a
+new line, it is inserted before:
+
+```
+host all all ::1/128 ident
+```
+
+### Fixing incorrect `COLLATE` or `CTYPE`
+
+Synapse will refuse to set up a new database if it has the wrong values of
+`COLLATE` and `CTYPE` set, and will log warnings on existing databases. Using
+different locales can cause issues if the locale library is updated from
+underneath the database, or if a different version of the locale is used on any
+replicas.
+
+The safest way to fix the issue is to dump the database and recreate it with
+the correct locale parameter (as shown above). It is also possible to change the
+parameters on a live database and run a `REINDEX` on the entire database,
+however extreme care must be taken to avoid database corruption.
+
+Note that the above may fail with an error about duplicate rows if corruption
+has already occurred, and such duplicate rows will need to be manually removed.
+
+### Fixing inconsistent sequences error
+
+Synapse uses Postgres sequences to generate IDs for various tables. A sequence
+and associated table can get out of sync if, for example, Synapse has been
+downgraded and then upgraded again.
+
+To fix the issue shut down Synapse (including any and all workers) and run the
+SQL command included in the error message. Once done Synapse should start
+successfully.
@@ -28,7 +28,11 @@ async def ModuleApi.send_local_online_presence_to(users: Iterable[str]) -> None
 which can be given a list of local or remote MXIDs to broadcast known, online user
 presence to (for those users that the receiving user is considered interested in).
 It does not include state for users who are currently offline, and it can only be
-called on workers that support sending federation.
+called on workers that support sending federation. Additionally, this method must
+only be called from the process that has been configured to write to the
+the [presence stream](https://github.com/matrix-org/synapse/blob/master/docs/workers.md#stream-writers).
+By default, this is the main process, but another worker can be configured to do
+so.
 
 ### Module structure
 
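For orientation, a hedged usage sketch based only on the signature documented above; `module_api` and the MXIDs are placeholders, and it assumes the call happens on the presence-writer process:

```python
from typing import Iterable


async def broadcast_known_presence(module_api, users: Iterable[str]) -> None:
    # Per the documentation above: only currently-online users are included,
    # and this must run on the worker configured to write to the presence
    # stream (the main process by default).
    await module_api.send_local_online_presence_to(list(users))


# Hypothetical call site inside a module:
# await broadcast_known_presence(self._api, ["@alice:example.org", "@bob:example.org"])
```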
@@ -683,33 +683,6 @@ acme:
 #
 account_key_file: DATADIR/acme_account.key
 
-# List of allowed TLS fingerprints for this server to publish along
-# with the signing keys for this server. Other matrix servers that
-# make HTTPS requests to this server will check that the TLS
-# certificates returned by this server match one of the fingerprints.
-#
-# Synapse automatically adds the fingerprint of its own certificate
-# to the list. So if federation traffic is handled directly by synapse
-# then no modification to the list is required.
-#
-# If synapse is run behind a load balancer that handles the TLS then it
-# will be necessary to add the fingerprints of the certificates used by
-# the loadbalancers to this list if they are different to the one
-# synapse is using.
-#
-# Homeservers are permitted to cache the list of TLS fingerprints
-# returned in the key responses up to the "valid_until_ts" returned in
-# key. It may be necessary to publish the fingerprints of a new
-# certificate and wait until the "valid_until_ts" of the previous key
-# responses have passed before deploying it.
-#
-# You can calculate a fingerprint from a given TLS listener via:
-# openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
-# openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
-# or by checking matrix.org/federationtester/api/report?server_name=$host
-#
-#tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
-
 
 ## Federation ##
 
@@ -2845,7 +2818,8 @@ opentracing:
 #enabled: true
 
 # The list of homeservers we wish to send and receive span contexts and span baggage.
-# See docs/opentracing.rst
+# See docs/opentracing.rst.
+#
 # This is a list of regexes which are matched against the server_name of the
 # homeserver.
 #
|
@ -2854,19 +2828,26 @@ opentracing:
|
||||||
#homeserver_whitelist:
|
#homeserver_whitelist:
|
||||||
# - ".*"
|
# - ".*"
|
||||||
|
|
||||||
|
# A list of the matrix IDs of users whose requests will always be traced,
|
||||||
|
# even if the tracing system would otherwise drop the traces due to
|
||||||
|
# probabilistic sampling.
|
||||||
|
#
|
||||||
|
# By default, the list is empty.
|
||||||
|
#
|
||||||
|
#force_tracing_for_users:
|
||||||
|
# - "@user1:server_name"
|
||||||
|
# - "@user2:server_name"
|
||||||
|
|
||||||
# Jaeger can be configured to sample traces at different rates.
|
# Jaeger can be configured to sample traces at different rates.
|
||||||
# All configuration options provided by Jaeger can be set here.
|
# All configuration options provided by Jaeger can be set here.
|
||||||
# Jaeger's configuration mostly related to trace sampling which
|
# Jaeger's configuration is mostly related to trace sampling which
|
||||||
# is documented here:
|
# is documented here:
|
||||||
# https://www.jaegertracing.io/docs/1.13/sampling/.
|
# https://www.jaegertracing.io/docs/latest/sampling/.
|
||||||
#
|
#
|
||||||
#jaeger_config:
|
#jaeger_config:
|
||||||
# sampler:
|
# sampler:
|
||||||
# type: const
|
# type: const
|
||||||
# param: 1
|
# param: 1
|
||||||
|
|
||||||
# Logging whether spans were started and reported
|
|
||||||
#
|
|
||||||
# logging:
|
# logging:
|
||||||
# false
|
# false
|
||||||
|
|
||||||
|
@@ -2935,3 +2916,18 @@ redis:
 # Optional password if configured on the Redis instance
 #
 #password: <secret_password>
+
+
+# Enable experimental features in Synapse.
+#
+# Experimental features might break or be removed without a deprecation
+# period.
+#
+experimental_features:
+  # Support for Spaces (MSC1772), it enables the following:
+  #
+  # * The Spaces Summary API (MSC2946).
+  # * Restricting room membership based on space membership (MSC3083).
+  #
+  # Uncomment to disable support for Spaces.
+  #spaces_enabled: false
@@ -67,8 +67,8 @@ A custom mapping provider must specify the following methods:
 - Arguments:
 - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
 information from.
-- This method must return a string, which is the unique identifier for the
-user. Commonly the ``sub`` claim of the response.
+- This method must return a string, which is the unique, immutable identifier
+for the user. Commonly the `sub` claim of the response.
 * `map_user_attributes(self, userinfo, token, failures)`
 - This method must be async.
 - Arguments:
@@ -87,7 +87,9 @@ A custom mapping provider must specify the following methods:
 `localpart` value, such as `john.doe1`.
 - Returns a dictionary with two keys:
 - `localpart`: A string, used to generate the Matrix ID. If this is
-`None`, the user is prompted to pick their own username.
+`None`, the user is prompted to pick their own username. This is only used
+during a user's first login. Once a localpart has been associated with a
+remote user ID (see `get_remote_user_id`) it cannot be updated.
 - `displayname`: An optional string, the display name for the user.
 * `get_extra_attributes(self, userinfo, token)`
 - This method must be async.
@@ -153,8 +155,8 @@ A custom mapping provider must specify the following methods:
 information from.
 - `client_redirect_url` - A string, the URL that the client will be
 redirected to.
-- This method must return a string, which is the unique identifier for the
-user. Commonly the ``uid`` claim of the response.
+- This method must return a string, which is the unique, immutable identifier
+for the user. Commonly the `uid` claim of the response.
 * `saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
 - Arguments:
 - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
@@ -172,8 +174,10 @@ A custom mapping provider must specify the following methods:
 redirected to.
 - This method must return a dictionary, which will then be used by Synapse
 to build a new user. The following keys are allowed:
-* `mxid_localpart` - The mxid localpart of the new user. If this is
-`None`, the user is prompted to pick their own username.
+* `mxid_localpart` - A string, the mxid localpart of the new user. If this is
+`None`, the user is prompted to pick their own username. This is only used
+during a user's first login. Once a localpart has been associated with a
+remote user ID (see `get_remote_user_id`) it cannot be updated.
 * `displayname` - The displayname of the new user. If not provided, will default to
 the value of `mxid_localpart`.
 * `emails` - A list of emails for the new user. If not provided, will
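To make the documented OIDC methods concrete, here is a minimal, illustrative mapping provider sketch. It is an assumption-laden example, not Synapse's built-in provider: the class name, config handling and the `preferred_username`/`name` claims are placeholders.

```python
class ExampleOidcMappingProvider:
    """Sketch of a provider implementing the methods documented above."""

    def __init__(self, parsed_config):
        self._config = parsed_config

    def get_remote_user_id(self, userinfo):
        # Must be a unique, immutable identifier for the user; the docs note
        # this is commonly the `sub` claim of the OIDC response.
        return userinfo["sub"]

    async def map_user_attributes(self, userinfo, token, failures):
        # `failures` counts earlier localpart collisions; appending it yields
        # e.g. "john.doe1" as described above.
        localpart = userinfo.get("preferred_username")
        if localpart and failures:
            localpart = f"{localpart}{failures}"
        return {
            # Only consulted on the user's first login; once bound to the
            # remote user ID it cannot be updated.
            "localpart": localpart,
            "displayname": userinfo.get("name"),
        }
```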
@@ -65,3 +65,33 @@ systemctl restart matrix-synapse-worker@federation_reader.service
 systemctl enable matrix-synapse-worker@federation_writer.service
 systemctl restart matrix-synapse.target
 ```
+
+## Hardening
+
+**Optional:** If further hardening is desired, the file
+`override-hardened.conf` may be copied from
+`contrib/systemd/override-hardened.conf` in this repository to the location
+`/etc/systemd/system/matrix-synapse.service.d/override-hardened.conf` (the
+directory may have to be created). It enables certain sandboxing features in
+systemd to further secure the synapse service. You may read the comments to
+understand what the override file is doing. The same file will need to be copied
+to
+`/etc/systemd/system/matrix-synapse-worker@.service.d/override-hardened-worker.conf`
+(this directory may also have to be created) in order to apply the same
+hardening options to any worker processes.
+
+Once these files have been copied to their appropriate locations, simply reload
+systemd's manager config files and restart all Synapse services to apply the hardening options. They will automatically
+be applied at every restart as long as the override files are present at the
+specified locations.
+
+```sh
+systemctl daemon-reload
+
+# Restart services
+systemctl restart matrix-synapse.target
+```
+
+In order to see their effect, you may run `systemd-analyze security
+matrix-synapse.service` before and after applying the hardening options to see
+the changes being applied at a glance.
@@ -7,6 +7,6 @@ who are present in a publicly viewable room present on the server.
 
 The directory info is stored in various tables, which can (typically after
 DB corruption) get stale or out of sync. If this happens, for now the
-solution to fix it is to execute the SQL [here](../synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql)
+solution to fix it is to execute the SQL [here](https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/main/delta/53/user_dir_populate.sql)
 and then restart synapse. This should then start a background task to
 flush the current tables and regenerate the directory.
12 mypy.ini
@@ -71,8 +71,13 @@ files =
 synapse/types.py,
 synapse/util/async_helpers.py,
 synapse/util/caches,
+synapse/util/daemonize.py,
+synapse/util/hash.py,
+synapse/util/iterutils.py,
 synapse/util/metrics.py,
 synapse/util/macaroons.py,
+synapse/util/module_loader.py,
+synapse/util/msisdn.py,
 synapse/util/stringutils.py,
 synapse/visibility.py,
 tests/replication,
@@ -80,6 +85,7 @@ files =
 tests/handlers/test_password_providers.py,
 tests/rest/client/v1/test_login.py,
 tests/rest/client/v2_alpha/test_auth.py,
+tests/util/test_itertools.py,
 tests/util/test_stream_change_cache.py
 
 [mypy-pymacaroons.*]
@@ -174,3 +180,9 @@ ignore_missing_imports = True
 
 [mypy-pympler.*]
 ignore_missing_imports = True
+
+[mypy-phonenumbers.*]
+ignore_missing_imports = True
+
+[mypy-ijson.*]
+ignore_missing_imports = True
@@ -27,12 +27,12 @@ DISTS = (
 "ubuntu:hirsute", # 21.04 (EOL 2022-01-05)
 )
 
-DESC = '''\
+DESC = """\
 Builds .debs for synapse, using a Docker image for the build environment.
 
 By default, builds for all known distributions, but a list of distributions
 can be passed on the commandline for debugging.
-'''
+"""
 
 
 class Builder(object):
@@ -68,7 +68,7 @@ class Builder(object):
 # we tend to get source packages which are full of debs. (We could hack
 # around that with more magic in the build_debian.sh script, but that
 # doesn't solve the problem for natively-run dpkg-buildpakage).
-debsdir = os.path.join(projdir, '../debs')
+debsdir = os.path.join(projdir, "../debs")
 os.makedirs(debsdir, exist_ok=True)
 
 if self.redirect_stdout:
@@ -79,30 +79,47 @@ class Builder(object):
 stdout = None
 
 # first build a docker image for the build environment
-subprocess.check_call([
-"docker", "build",
-"--tag", "dh-venv-builder:" + tag,
-"--build-arg", "distro=" + dist,
-"-f", "docker/Dockerfile-dhvirtualenv",
+subprocess.check_call(
+[
+"docker",
+"build",
+"--tag",
+"dh-venv-builder:" + tag,
+"--build-arg",
+"distro=" + dist,
+"-f",
+"docker/Dockerfile-dhvirtualenv",
 "docker",
-], stdout=stdout, stderr=subprocess.STDOUT)
+],
+stdout=stdout,
+stderr=subprocess.STDOUT,
+)
 
 container_name = "synapse_build_" + tag
 with self._lock:
 self.active_containers.add(container_name)
 
 # then run the build itself
-subprocess.check_call([
-"docker", "run",
+subprocess.check_call(
+[
+"docker",
+"run",
 "--rm",
-"--name", container_name,
+"--name",
+container_name,
 "--volume=" + projdir + ":/synapse/source:ro",
 "--volume=" + debsdir + ":/debs",
-"-e", "TARGET_USERID=%i" % (os.getuid(), ),
-"-e", "TARGET_GROUPID=%i" % (os.getgid(), ),
-"-e", "DEB_BUILD_OPTIONS=%s" % ("nocheck" if skip_tests else ""),
+"-e",
+"TARGET_USERID=%i" % (os.getuid(),),
+"-e",
+"TARGET_GROUPID=%i" % (os.getgid(),),
+"-e",
+"DEB_BUILD_OPTIONS=%s" % ("nocheck" if skip_tests else ""),
 "dh-venv-builder:" + tag,
-], stdout=stdout, stderr=subprocess.STDOUT)
+],
+stdout=stdout,
+stderr=subprocess.STDOUT,
+)
 
 with self._lock:
 self.active_containers.remove(container_name)
|
||||||
|
|
||||||
for c in active:
|
for c in active:
|
||||||
print("killing container %s" % (c,))
|
print("killing container %s" % (c,))
|
||||||
subprocess.run([
|
subprocess.run(
|
||||||
"docker", "kill", c,
|
[
|
||||||
], stdout=subprocess.DEVNULL)
|
"docker",
|
||||||
|
"kill",
|
||||||
|
c,
|
||||||
|
],
|
||||||
|
stdout=subprocess.DEVNULL,
|
||||||
|
)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.active_containers.remove(c)
|
self.active_containers.remove(c)
|
||||||
|
|
||||||
|
@@ -130,31 +152,38 @@ def run_builds(dists, jobs=1, skip_tests=False):
 def sig(signum, _frame):
 print("Caught SIGINT")
 builder.kill_containers()
 
 signal.signal(signal.SIGINT, sig)
 
 with ThreadPoolExecutor(max_workers=jobs) as e:
 res = e.map(lambda dist: builder.run_build(dist, skip_tests), dists)
 
 # make sure we consume the iterable so that exceptions are raised.
-for r in res:
+for _ in res:
 pass
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 parser = argparse.ArgumentParser(
 description=DESC,
 )
 parser.add_argument(
-'-j', '--jobs', type=int, default=1,
-help='specify the number of builds to run in parallel',
+"-j",
+"--jobs",
+type=int,
+default=1,
+help="specify the number of builds to run in parallel",
 )
 parser.add_argument(
-'--no-check', action='store_true',
-help='skip running tests after building',
+"--no-check",
+action="store_true",
+help="skip running tests after building",
 )
 parser.add_argument(
-'dist', nargs='*', default=DISTS,
-help='a list of distributions to build for. Default: %(default)s',
+"dist",
+nargs="*",
+default=DISTS,
+help="a list of distributions to build for. Default: %(default)s",
 )
 args = parser.parse_args()
 run_builds(dists=args.dist, jobs=args.jobs, skip_tests=args.no_check)
@@ -10,6 +10,9 @@
 # checkout by setting the COMPLEMENT_DIR environment variable to the
 # filepath of a local Complement checkout.
 #
+# By default Synapse is run in monolith mode. This can be overridden by
+# setting the WORKERS environment variable.
+#
 # A regular expression of test method names can be supplied as the first
 # argument to the script. Complement will then only run those tests. If
 # no regex is supplied, all tests are run. For example;
@@ -32,10 +35,26 @@ if [[ -z "$COMPLEMENT_DIR" ]]; then
 echo "Checkout available at 'complement-master'"
 fi
 
+# If we're using workers, modify the docker files slightly.
+if [[ -n "$WORKERS" ]]; then
+  BASE_IMAGE=matrixdotorg/synapse-workers
+  BASE_DOCKERFILE=docker/Dockerfile-workers
+  export COMPLEMENT_BASE_IMAGE=complement-synapse-workers
+  COMPLEMENT_DOCKERFILE=SynapseWorkers.Dockerfile
+  # And provide some more configuration to complement.
+  export COMPLEMENT_CA=true
+  export COMPLEMENT_VERSION_CHECK_ITERATIONS=500
+else
+  BASE_IMAGE=matrixdotorg/synapse
+  BASE_DOCKERFILE=docker/Dockerfile
+  export COMPLEMENT_BASE_IMAGE=complement-synapse
+  COMPLEMENT_DOCKERFILE=Synapse.Dockerfile
+fi
+
 # Build the base Synapse image from the local checkout
-docker build -t matrixdotorg/synapse -f docker/Dockerfile .
+docker build -t $BASE_IMAGE -f "$BASE_DOCKERFILE" .
 # Build the Synapse monolith image from Complement, based on the above image we just built
-docker build -t complement-synapse -f "$COMPLEMENT_DIR/dockerfiles/Synapse.Dockerfile" "$COMPLEMENT_DIR/dockerfiles"
+docker build -t $COMPLEMENT_BASE_IMAGE -f "$COMPLEMENT_DIR/dockerfiles/$COMPLEMENT_DOCKERFILE" "$COMPLEMENT_DIR/dockerfiles"
 
 cd "$COMPLEMENT_DIR"
 
@@ -46,4 +65,4 @@ if [[ -n "$1" ]]; then
 fi
 
 # Run the tests!
-COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -tags synapse_blacklist,msc2946,msc3083 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests
+go test -v -tags synapse_blacklist,msc2946,msc3083 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests
@@ -1,4 +1,3 @@
-import hashlib
 import json
 import sys
 import time
@@ -54,15 +53,9 @@ def convert_v1_to_v2(server_name, valid_until, keys, certificate):
 "server_name": server_name,
 "verify_keys": {key_id: {"key": key} for key_id, key in keys.items()},
 "valid_until_ts": valid_until,
-"tls_fingerprints": [fingerprint(certificate)],
 }
 
 
-def fingerprint(certificate):
-finger = hashlib.sha256(certificate)
-return {"sha256": encode_base64(finger.digest())}
-
-
 def rows_v2(server, json):
 valid_until = json["valid_until_ts"]
 key_json = encode_canonical_json(json)
@@ -80,8 +80,22 @@ else
 # then lint everything!
 if [[ -z ${files+x} ]]; then
 # Lint all source code files and directories
-# Note: this list aims the mirror the one in tox.ini
-files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite")
+# Note: this list aims to mirror the one in tox.ini
+files=(
+"synapse" "docker" "tests"
+# annoyingly, black doesn't find these so we have to list them
+"scripts/export_signing_key"
+"scripts/generate_config"
+"scripts/generate_log_config"
+"scripts/hash_password"
+"scripts/register_new_matrix_user"
+"scripts/synapse_port_db"
+"scripts-dev"
+"scripts-dev/build_debian_packages"
+"scripts-dev/sign_json"
+"scripts-dev/update_database"
+"contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite"
+)
 fi
 fi
 
@@ -30,7 +30,11 @@ def exit(status: int = 0, message: Optional[str] = None):
 def format_plain(public_key: nacl.signing.VerifyKey):
 print(
 "%s:%s %s"
-% (public_key.alg, public_key.version, encode_verify_key_base64(public_key),)
+% (
+public_key.alg,
+public_key.version,
+encode_verify_key_base64(public_key),
+)
 )
 
 
@@ -50,7 +54,10 @@ if __name__ == "__main__":
 parser = argparse.ArgumentParser()
 
 parser.add_argument(
-"key_file", nargs="+", type=argparse.FileType("r"), help="The key file to read",
+"key_file",
+nargs="+",
+type=argparse.FileType("r"),
+help="The key file to read",
 )
 
 parser.add_argument(
@@ -11,7 +11,6 @@ if __name__ == "__main__":
 parser.add_argument(
 "--config-dir",
 default="CONFDIR",
-
 help="The path where the config files are kept. Used to create filenames for "
 "things like the log config and the signing key. Default: %(default)s",
 )
@@ -41,19 +40,20 @@ if __name__ == "__main__":
 "--generate-secrets",
 action="store_true",
 help="Enable generation of new secrets for things like the macaroon_secret_key."
-"By default, these parameters will be left unset."
+"By default, these parameters will be left unset.",
 )
 
 parser.add_argument(
-"-o", "--output-file",
-type=argparse.FileType('w'),
+"-o",
+"--output-file",
+type=argparse.FileType("w"),
 default=sys.stdout,
 help="File to write the configuration to. Default: stdout",
 )
 
 parser.add_argument(
 "--header-file",
-type=argparse.FileType('r'),
+type=argparse.FileType("r"),
 help="File from which to read a header, which will be printed before the "
 "generated config.",
 )
@ -41,7 +41,7 @@ if __name__ == "__main__":
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-c",
|
"-c",
|
||||||
"--config",
|
"--config",
|
||||||
type=argparse.FileType('r'),
|
type=argparse.FileType("r"),
|
||||||
help=(
|
help=(
|
||||||
"Path to server config file. "
|
"Path to server config file. "
|
||||||
"Used to read in bcrypt_rounds and password_pepper."
|
"Used to read in bcrypt_rounds and password_pepper."
|
||||||
|
@ -72,8 +72,8 @@ if __name__ == "__main__":
|
||||||
pw = unicodedata.normalize("NFKC", password)
|
pw = unicodedata.normalize("NFKC", password)
|
||||||
|
|
||||||
hashed = bcrypt.hashpw(
|
hashed = bcrypt.hashpw(
|
||||||
pw.encode('utf8') + password_pepper.encode("utf8"),
|
pw.encode("utf8") + password_pepper.encode("utf8"),
|
||||||
bcrypt.gensalt(bcrypt_rounds),
|
bcrypt.gensalt(bcrypt_rounds),
|
||||||
).decode('ascii')
|
).decode("ascii")
|
||||||
|
|
||||||
print(hashed)
|
print(hashed)
|
||||||
|
|
|
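The hash printed above can be checked later with bcrypt.checkpw, provided the same pepper is appended again before verification. A minimal sketch of that round trip (the pepper value and round count here are illustrative, not the server's configuration):

import bcrypt
import unicodedata

password_pepper = "example-pepper"  # assumption: same value as password_config.pepper
pw = unicodedata.normalize("NFKC", "s3kr3t")
hashed = bcrypt.hashpw(
    pw.encode("utf8") + password_pepper.encode("utf8"),
    bcrypt.gensalt(12),
).decode("ascii")
# Verification must re-apply the pepper, otherwise checkpw always fails.
assert bcrypt.checkpw(
    pw.encode("utf8") + password_pepper.encode("utf8"),
    hashed.encode("ascii"),
)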
@ -294,8 +294,7 @@ class Porter(object):
|
||||||
return table, already_ported, total_to_port, forward_chunk, backward_chunk
|
return table, already_ported, total_to_port, forward_chunk, backward_chunk
|
||||||
|
|
||||||
async def get_table_constraints(self) -> Dict[str, Set[str]]:
|
async def get_table_constraints(self) -> Dict[str, Set[str]]:
|
||||||
"""Returns a map of tables that have foreign key constraints to tables they depend on.
|
"""Returns a map of tables that have foreign key constraints to tables they depend on."""
|
||||||
"""
|
|
||||||
|
|
||||||
def _get_constraints(txn):
|
def _get_constraints(txn):
|
||||||
# We can pull the information about foreign key constraints out from
|
# We can pull the information about foreign key constraints out from
|
||||||
|
@ -504,7 +503,9 @@ class Porter(object):
|
||||||
return
|
return
|
||||||
|
|
||||||
def build_db_store(
|
def build_db_store(
|
||||||
self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False,
|
self,
|
||||||
|
db_config: DatabaseConnectionConfig,
|
||||||
|
allow_outdated_version: bool = False,
|
||||||
):
|
):
|
||||||
"""Builds and returns a database store using the provided configuration.
|
"""Builds and returns a database store using the provided configuration.
|
||||||
|
|
||||||
|
@ -740,7 +741,7 @@ class Porter(object):
|
||||||
return col
|
return col
|
||||||
|
|
||||||
outrows = []
|
outrows = []
|
||||||
for i, row in enumerate(rows):
|
for row in rows:
|
||||||
try:
|
try:
|
||||||
outrows.append(
|
outrows.append(
|
||||||
tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
|
tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
|
||||||
|
@ -890,8 +891,7 @@ class Porter(object):
|
||||||
await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
|
await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
|
||||||
|
|
||||||
async def _setup_events_stream_seqs(self) -> None:
|
async def _setup_events_stream_seqs(self) -> None:
|
||||||
"""Set the event stream sequences to the correct values.
|
"""Set the event stream sequences to the correct values."""
|
||||||
"""
|
|
||||||
|
|
||||||
# We get called before we've ported the events table, so we need to
|
# We get called before we've ported the events table, so we need to
|
||||||
# fetch the current positions from the SQLite store.
|
# fetch the current positions from the SQLite store.
|
||||||
|
@ -920,12 +920,14 @@ class Porter(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.postgres_store.db_pool.runInteraction(
|
await self.postgres_store.db_pool.runInteraction(
|
||||||
"_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
|
"_setup_events_stream_seqs",
|
||||||
|
_setup_events_stream_seqs_set_pos,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _setup_sequence(self, sequence_name: str, stream_id_tables: Iterable[str]) -> None:
|
async def _setup_sequence(
|
||||||
"""Set a sequence to the correct value.
|
self, sequence_name: str, stream_id_tables: Iterable[str]
|
||||||
"""
|
) -> None:
|
||||||
|
"""Set a sequence to the correct value."""
|
||||||
current_stream_ids = []
|
current_stream_ids = []
|
||||||
for stream_id_table in stream_id_tables:
|
for stream_id_table in stream_id_tables:
|
||||||
max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
|
@ -942,17 +944,22 @@ class Porter(object):
|
||||||
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
|
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
|
||||||
txn.execute(sql + " %s", (next_id,))
|
txn.execute(sql + " %s", (next_id,))
|
||||||
|
|
||||||
await self.postgres_store.db_pool.runInteraction("_setup_%s" % (sequence_name,), r)
|
await self.postgres_store.db_pool.runInteraction(
|
||||||
|
"_setup_%s" % (sequence_name,), r
|
||||||
|
)
|
||||||
|
|
||||||
async def _setup_auth_chain_sequence(self) -> None:
|
async def _setup_auth_chain_sequence(self) -> None:
|
||||||
curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
table="event_auth_chains", keyvalues={}, retcol="MAX(chain_id)", allow_none=True
|
table="event_auth_chains",
|
||||||
|
keyvalues={},
|
||||||
|
retcol="MAX(chain_id)",
|
||||||
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def r(txn):
|
def r(txn):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
|
"ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
|
||||||
(curr_chain_id,),
|
(curr_chain_id + 1,),
|
||||||
)
|
)
|
||||||
|
|
||||||
if curr_chain_id is not None:
|
if curr_chain_id is not None:
|
||||||
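The "+ 1" above matters because Postgres' ALTER SEQUENCE ... RESTART WITH n makes n the next value the sequence hands out; restarting with the current maximum chain ID would let the ported database reissue an ID that is already in use. A hedged sketch of the idea, using the same sequence name as the script above:

def restart_sequence(txn, max_in_use: int) -> None:
    # RESTART WITH takes the *next* value to return, so advance past the
    # largest value already present in the ported rows.
    txn.execute(
        "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
        (max_in_use + 1,),
    )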
|
@ -968,8 +975,7 @@ class Porter(object):
|
||||||
|
|
||||||
|
|
||||||
class Progress(object):
|
class Progress(object):
|
||||||
"""Used to report progress of the port
|
"""Used to report progress of the port"""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.tables = {}
|
self.tables = {}
|
||||||
|
@ -994,8 +1000,7 @@ class Progress(object):
|
||||||
|
|
||||||
|
|
||||||
class CursesProgress(Progress):
|
class CursesProgress(Progress):
|
||||||
"""Reports progress to a curses window
|
"""Reports progress to a curses window"""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, stdscr):
|
def __init__(self, stdscr):
|
||||||
self.stdscr = stdscr
|
self.stdscr = stdscr
|
||||||
|
@ -1020,7 +1025,7 @@ class CursesProgress(Progress):
|
||||||
|
|
||||||
self.total_processed = 0
|
self.total_processed = 0
|
||||||
self.total_remaining = 0
|
self.total_remaining = 0
|
||||||
for table, data in self.tables.items():
|
for data in self.tables.values():
|
||||||
self.total_processed += data["num_done"] - data["start"]
|
self.total_processed += data["num_done"] - data["start"]
|
||||||
self.total_remaining += data["total"] - data["num_done"]
|
self.total_remaining += data["total"] - data["num_done"]
|
||||||
|
|
||||||
|
@ -1111,8 +1116,7 @@ class CursesProgress(Progress):
|
||||||
|
|
||||||
|
|
||||||
class TerminalProgress(Progress):
|
class TerminalProgress(Progress):
|
||||||
"""Just prints progress to the terminal
|
"""Just prints progress to the terminal"""
|
||||||
"""
|
|
||||||
|
|
||||||
def update(self, table, num_done):
|
def update(self, table, num_done):
|
||||||
super(TerminalProgress, self).update(table, num_done)
|
super(TerminalProgress, self).update(table, num_done)
|
||||||
|
|
|
@ -47,7 +47,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
__version__ = "1.34.0"
|
__version__ = "1.35.0rc1"
|
||||||
|
|
||||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||||
# We import here so that we don't have to install a bunch of deps when
|
# We import here so that we don't have to install a bunch of deps when
|
||||||
|
|
|
@ -87,6 +87,7 @@ class Auth:
|
||||||
)
|
)
|
||||||
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
|
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
|
||||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||||
|
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
|
||||||
|
|
||||||
async def check_from_context(
|
async def check_from_context(
|
||||||
self, room_version: str, event, context, do_sig_check=True
|
self, room_version: str, event, context, do_sig_check=True
|
||||||
|
@ -208,6 +209,8 @@ class Auth:
|
||||||
opentracing.set_tag("authenticated_entity", user_id)
|
opentracing.set_tag("authenticated_entity", user_id)
|
||||||
opentracing.set_tag("user_id", user_id)
|
opentracing.set_tag("user_id", user_id)
|
||||||
opentracing.set_tag("appservice_id", app_service.id)
|
opentracing.set_tag("appservice_id", app_service.id)
|
||||||
|
if user_id in self._force_tracing_for_users:
|
||||||
|
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
|
||||||
|
|
||||||
return requester
|
return requester
|
||||||
|
|
||||||
|
@ -260,6 +263,8 @@ class Auth:
|
||||||
opentracing.set_tag("user_id", user_info.user_id)
|
opentracing.set_tag("user_id", user_info.user_id)
|
||||||
if device_id:
|
if device_id:
|
||||||
opentracing.set_tag("device_id", device_id)
|
opentracing.set_tag("device_id", device_id)
|
||||||
|
if user_info.token_owner in self._force_tracing_for_users:
|
||||||
|
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
|
||||||
|
|
||||||
return requester
|
return requester
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
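A note on what the new tag does: setting a positive sampling priority on the request's span asks the tracer to keep the whole trace even when probabilistic sampling would have dropped it, which is how force_tracing_for_users takes effect. A minimal sketch with the plain opentracing package (Synapse goes through its synapse.logging.opentracing wrapper, but the tag constant is the same):

from opentracing.ext import tags

def maybe_force_trace(span, requester: str, forced_users: set) -> None:
    # sampling.priority > 0 asks Jaeger to retain this trace.
    if requester in forced_users:
        span.set_tag(tags.SAMPLING_PRIORITY, 1)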
|
|
@ -61,7 +61,6 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore
|
||||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
from synapse.replication.slave.storage.room import RoomStore
|
from synapse.replication.slave.storage.room import RoomStore
|
||||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
|
||||||
from synapse.rest.admin import register_servlets_for_media_repo
|
from synapse.rest.admin import register_servlets_for_media_repo
|
||||||
from synapse.rest.client.v1 import events, login, presence, room
|
from synapse.rest.client.v1 import events, login, presence, room
|
||||||
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
|
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
|
||||||
|
@ -237,7 +236,6 @@ class GenericWorkerSlavedStore(
|
||||||
DirectoryStore,
|
DirectoryStore,
|
||||||
SlavedApplicationServiceStore,
|
SlavedApplicationServiceStore,
|
||||||
SlavedRegistrationStore,
|
SlavedRegistrationStore,
|
||||||
SlavedTransactionStore,
|
|
||||||
SlavedProfileStore,
|
SlavedProfileStore,
|
||||||
SlavedClientIpStore,
|
SlavedClientIpStore,
|
||||||
SlavedFilteringStore,
|
SlavedFilteringStore,
|
||||||
|
|
|
@ -29,9 +29,26 @@ class ExperimentalConfig(Config):
|
||||||
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
|
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
|
||||||
|
|
||||||
# Spaces (MSC1772, MSC2946, MSC3083, etc)
|
# Spaces (MSC1772, MSC2946, MSC3083, etc)
|
||||||
self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
|
self.spaces_enabled = experimental.get("spaces_enabled", True) # type: bool
|
||||||
if self.spaces_enabled:
|
if self.spaces_enabled:
|
||||||
KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083
|
KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083
|
||||||
|
|
||||||
# MSC3026 (busy presence state)
|
# MSC3026 (busy presence state)
|
||||||
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
|
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
|
||||||
|
|
||||||
|
def generate_config_section(self, **kwargs):
|
||||||
|
return """\
|
||||||
|
# Enable experimental features in Synapse.
|
||||||
|
#
|
||||||
|
# Experimental features might break or be removed without a deprecation
|
||||||
|
# period.
|
||||||
|
#
|
||||||
|
experimental_features:
|
||||||
|
# Support for Spaces (MSC1772), it enables the following:
|
||||||
|
#
|
||||||
|
# * The Spaces Summary API (MSC2946).
|
||||||
|
# * Restricting room membership based on space membership (MSC3083).
|
||||||
|
#
|
||||||
|
# Uncomment to disable support for Spaces.
|
||||||
|
#spaces_enabled: false
|
||||||
|
"""
|
||||||
|
|
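Since the default has flipped to enabled, an operator who wants the old behaviour now has to opt out explicitly. A rough illustration of how the lookup above behaves in both cases (the dict literals stand in for a parsed experimental_features section of homeserver.yaml):

# Key absent: Spaces support (MSC2946 summary API, MSC3083 join rules) is on.
assert {}.get("spaces_enabled", True) is True

# Explicit opt-out, matching the commented-out example in the generated config.
experimental = {"spaces_enabled": False}
assert experimental.get("spaces_enabled", True) is False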
|
@ -59,7 +59,6 @@ class HomeServerConfig(RootConfig):
|
||||||
config_classes = [
|
config_classes = [
|
||||||
MeowConfig,
|
MeowConfig,
|
||||||
ServerConfig,
|
ServerConfig,
|
||||||
ExperimentalConfig,
|
|
||||||
TlsConfig,
|
TlsConfig,
|
||||||
FederationConfig,
|
FederationConfig,
|
||||||
CacheConfig,
|
CacheConfig,
|
||||||
|
@ -96,4 +95,5 @@ class HomeServerConfig(RootConfig):
|
||||||
TracerConfig,
|
TracerConfig,
|
||||||
WorkerConfig,
|
WorkerConfig,
|
||||||
RedisConfig,
|
RedisConfig,
|
||||||
|
ExperimentalConfig,
|
||||||
]
|
]
|
||||||
|
|
|
@ -349,4 +349,4 @@ class RegistrationConfig(Config):
|
||||||
|
|
||||||
def read_arguments(self, args):
|
def read_arguments(self, args):
|
||||||
if args.enable_registration is not None:
|
if args.enable_registration is not None:
|
||||||
self.enable_registration = bool(strtobool(str(args.enable_registration)))
|
self.enable_registration = strtobool(str(args.enable_registration))
|
||||||
|
|
|
@ -164,7 +164,13 @@ class SAML2Config(Config):
|
||||||
config_path = saml2_config.get("config_path", None)
|
config_path = saml2_config.get("config_path", None)
|
||||||
if config_path is not None:
|
if config_path is not None:
|
||||||
mod = load_python_module(config_path)
|
mod = load_python_module(config_path)
|
||||||
_dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict)
|
config = getattr(mod, "CONFIG", None)
|
||||||
|
if config is None:
|
||||||
|
raise ConfigError(
|
||||||
|
"Config path specified by saml2_config.config_path does not "
|
||||||
|
"have a CONFIG property."
|
||||||
|
)
|
||||||
|
_dict_merge(merge_dict=config, into_dict=saml2_config_dict)
|
||||||
|
|
||||||
import saml2.config
|
import saml2.config
|
||||||
|
|
||||||
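With this check, the module named by saml2_config.config_path must expose a module-level CONFIG dict for pysaml2; pointing it at a file without one now fails with a clear ConfigError instead of an unhelpful AttributeError later. A minimal example of such a module (the entityid is a placeholder):

# saml2_extra_config.py -- referenced via saml2_config.config_path
CONFIG = {
    "entityid": "https://example.org/_synapse/client/saml2/metadata.xml",
}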
|
|
|
@ -16,11 +16,8 @@ import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from hashlib import sha256
|
|
||||||
from typing import List, Optional, Pattern
|
from typing import List, Optional, Pattern
|
||||||
|
|
||||||
from unpaddedbase64 import encode_base64
|
|
||||||
|
|
||||||
from OpenSSL import SSL, crypto
|
from OpenSSL import SSL, crypto
|
||||||
from twisted.internet._sslverify import Certificate, trustRootFromCertificates
|
from twisted.internet._sslverify import Certificate, trustRootFromCertificates
|
||||||
|
|
||||||
|
@ -83,13 +80,6 @@ class TlsConfig(Config):
|
||||||
"configured."
|
"configured."
|
||||||
)
|
)
|
||||||
|
|
||||||
self._original_tls_fingerprints = config.get("tls_fingerprints", [])
|
|
||||||
|
|
||||||
if self._original_tls_fingerprints is None:
|
|
||||||
self._original_tls_fingerprints = []
|
|
||||||
|
|
||||||
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
|
||||||
|
|
||||||
# Whether to verify certificates on outbound federation traffic
|
# Whether to verify certificates on outbound federation traffic
|
||||||
self.federation_verify_certificates = config.get(
|
self.federation_verify_certificates = config.get(
|
||||||
"federation_verify_certificates", True
|
"federation_verify_certificates", True
|
||||||
|
@ -248,19 +238,6 @@ class TlsConfig(Config):
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
|
||||||
|
|
||||||
if self.tls_certificate:
|
|
||||||
# Check that our own certificate is included in the list of fingerprints
|
|
||||||
# and include it if it is not.
|
|
||||||
x509_certificate_bytes = crypto.dump_certificate(
|
|
||||||
crypto.FILETYPE_ASN1, self.tls_certificate
|
|
||||||
)
|
|
||||||
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
|
|
||||||
sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints}
|
|
||||||
if sha256_fingerprint not in sha256_fingerprints:
|
|
||||||
self.tls_fingerprints.append({"sha256": sha256_fingerprint})
|
|
||||||
|
|
||||||
def generate_config_section(
|
def generate_config_section(
|
||||||
self,
|
self,
|
||||||
config_dir_path,
|
config_dir_path,
|
||||||
|
@ -443,33 +420,6 @@ class TlsConfig(Config):
|
||||||
# If unspecified, we will use CONFDIR/client.key.
|
# If unspecified, we will use CONFDIR/client.key.
|
||||||
#
|
#
|
||||||
account_key_file: %(default_acme_account_file)s
|
account_key_file: %(default_acme_account_file)s
|
||||||
|
|
||||||
# List of allowed TLS fingerprints for this server to publish along
|
|
||||||
# with the signing keys for this server. Other matrix servers that
|
|
||||||
# make HTTPS requests to this server will check that the TLS
|
|
||||||
# certificates returned by this server match one of the fingerprints.
|
|
||||||
#
|
|
||||||
# Synapse automatically adds the fingerprint of its own certificate
|
|
||||||
# to the list. So if federation traffic is handled directly by synapse
|
|
||||||
# then no modification to the list is required.
|
|
||||||
#
|
|
||||||
# If synapse is run behind a load balancer that handles the TLS then it
|
|
||||||
# will be necessary to add the fingerprints of the certificates used by
|
|
||||||
# the loadbalancers to this list if they are different to the one
|
|
||||||
# synapse is using.
|
|
||||||
#
|
|
||||||
# Homeservers are permitted to cache the list of TLS fingerprints
|
|
||||||
# returned in the key responses up to the "valid_until_ts" returned in
|
|
||||||
# key. It may be necessary to publish the fingerprints of a new
|
|
||||||
# certificate and wait until the "valid_until_ts" of the previous key
|
|
||||||
# responses have passed before deploying it.
|
|
||||||
#
|
|
||||||
# You can calculate a fingerprint from a given TLS listener via:
|
|
||||||
# openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
|
|
||||||
# openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
|
|
||||||
# or by checking matrix.org/federationtester/api/report?server_name=$host
|
|
||||||
#
|
|
||||||
#tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
|
||||||
"""
|
"""
|
||||||
# Lowercase the string representation of boolean values
|
# Lowercase the string representation of boolean values
|
||||||
% {
|
% {
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Set
|
||||||
|
|
||||||
from synapse.python_dependencies import DependencyException, check_requirements
|
from synapse.python_dependencies import DependencyException, check_requirements
|
||||||
|
|
||||||
from ._base import Config, ConfigError
|
from ._base import Config, ConfigError
|
||||||
|
@ -32,6 +34,8 @@ class TracerConfig(Config):
|
||||||
{"sampler": {"type": "const", "param": 1}, "logging": False},
|
{"sampler": {"type": "const", "param": 1}, "logging": False},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.force_tracing_for_users: Set[str] = set()
|
||||||
|
|
||||||
if not self.opentracer_enabled:
|
if not self.opentracer_enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -48,6 +52,19 @@ class TracerConfig(Config):
|
||||||
if not isinstance(self.opentracer_whitelist, list):
|
if not isinstance(self.opentracer_whitelist, list):
|
||||||
raise ConfigError("Tracer homeserver_whitelist config is malformed")
|
raise ConfigError("Tracer homeserver_whitelist config is malformed")
|
||||||
|
|
||||||
|
force_tracing_for_users = opentracing_config.get("force_tracing_for_users", [])
|
||||||
|
if not isinstance(force_tracing_for_users, list):
|
||||||
|
raise ConfigError(
|
||||||
|
"Expected a list", ("opentracing", "force_tracing_for_users")
|
||||||
|
)
|
||||||
|
for i, u in enumerate(force_tracing_for_users):
|
||||||
|
if not isinstance(u, str):
|
||||||
|
raise ConfigError(
|
||||||
|
"Expected a string",
|
||||||
|
("opentracing", "force_tracing_for_users", f"index {i}"),
|
||||||
|
)
|
||||||
|
self.force_tracing_for_users.add(u)
|
||||||
|
|
||||||
def generate_config_section(cls, **kwargs):
|
def generate_config_section(cls, **kwargs):
|
||||||
return """\
|
return """\
|
||||||
## Opentracing ##
|
## Opentracing ##
|
||||||
|
@ -64,7 +81,8 @@ class TracerConfig(Config):
|
||||||
#enabled: true
|
#enabled: true
|
||||||
|
|
||||||
# The list of homeservers we wish to send and receive span contexts and span baggage.
|
# The list of homeservers we wish to send and receive span contexts and span baggage.
|
||||||
# See docs/opentracing.rst
|
# See docs/opentracing.rst.
|
||||||
|
#
|
||||||
# This is a list of regexes which are matched against the server_name of the
|
# This is a list of regexes which are matched against the server_name of the
|
||||||
# homeserver.
|
# homeserver.
|
||||||
#
|
#
|
||||||
|
@ -73,19 +91,26 @@ class TracerConfig(Config):
|
||||||
#homeserver_whitelist:
|
#homeserver_whitelist:
|
||||||
# - ".*"
|
# - ".*"
|
||||||
|
|
||||||
|
# A list of the matrix IDs of users whose requests will always be traced,
|
||||||
|
# even if the tracing system would otherwise drop the traces due to
|
||||||
|
# probabilistic sampling.
|
||||||
|
#
|
||||||
|
# By default, the list is empty.
|
||||||
|
#
|
||||||
|
#force_tracing_for_users:
|
||||||
|
# - "@user1:server_name"
|
||||||
|
# - "@user2:server_name"
|
||||||
|
|
||||||
# Jaeger can be configured to sample traces at different rates.
|
# Jaeger can be configured to sample traces at different rates.
|
||||||
# All configuration options provided by Jaeger can be set here.
|
# All configuration options provided by Jaeger can be set here.
|
||||||
# Jaeger's configuration mostly related to trace sampling which
|
# Jaeger's configuration is mostly related to trace sampling which
|
||||||
# is documented here:
|
# is documented here:
|
||||||
# https://www.jaegertracing.io/docs/1.13/sampling/.
|
# https://www.jaegertracing.io/docs/latest/sampling/.
|
||||||
#
|
#
|
||||||
#jaeger_config:
|
#jaeger_config:
|
||||||
# sampler:
|
# sampler:
|
||||||
# type: const
|
# type: const
|
||||||
# param: 1
|
# param: 1
|
||||||
|
|
||||||
# Logging whether spans were started and reported
|
|
||||||
#
|
|
||||||
# logging:
|
# logging:
|
||||||
# false
|
# false
|
||||||
"""
|
"""
|
||||||
|
|
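For clarity, the validation above accepts only a list of strings and reports the path to the first offending entry. A small stand-alone illustration of the same checks (not the Config class itself):

def check_force_tracing(value) -> set:
    # Mirrors the config check above: must be a list of Matrix user IDs.
    if not isinstance(value, list):
        raise ValueError("opentracing.force_tracing_for_users: expected a list")
    for i, u in enumerate(value):
        if not isinstance(u, str):
            raise ValueError(f"opentracing.force_tracing_for_users: index {i} is not a string")
    return set(value)

print(check_force_tracing(["@user1:server_name", "@user2:server_name"]))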
|
@ -17,7 +17,7 @@ import abc
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from signedjson.key import (
|
from signedjson.key import (
|
||||||
|
@ -42,6 +42,8 @@ from synapse.api.errors import (
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.config.key import TrustedKeyServer
|
from synapse.config.key import TrustedKeyServer
|
||||||
|
from synapse.events import EventBase
|
||||||
|
from synapse.events.utils import prune_event_dict
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
PreserveLoggingContext,
|
PreserveLoggingContext,
|
||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
|
@ -69,7 +71,11 @@ class VerifyJsonRequest:
|
||||||
Attributes:
|
Attributes:
|
||||||
server_name: The name of the server to verify against.
|
server_name: The name of the server to verify against.
|
||||||
|
|
||||||
json_object: The JSON object to verify.
|
get_json_object: A callback to fetch the JSON object to verify.
|
||||||
|
A callback is used to allow deferring the creation of the JSON
|
||||||
|
object to verify until needed, e.g. for events we can defer
|
||||||
|
creating the redacted copy. This reduces the memory usage when
|
||||||
|
there are large numbers of in flight requests.
|
||||||
|
|
||||||
minimum_valid_until_ts: time at which we require the signing key to
|
minimum_valid_until_ts: time at which we require the signing key to
|
||||||
be valid. (0 implies we don't care)
|
be valid. (0 implies we don't care)
|
||||||
|
@ -88,14 +94,50 @@ class VerifyJsonRequest:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
server_name = attr.ib(type=str)
|
server_name = attr.ib(type=str)
|
||||||
json_object = attr.ib(type=JsonDict)
|
get_json_object = attr.ib(type=Callable[[], JsonDict])
|
||||||
minimum_valid_until_ts = attr.ib(type=int)
|
minimum_valid_until_ts = attr.ib(type=int)
|
||||||
request_name = attr.ib(type=str)
|
request_name = attr.ib(type=str)
|
||||||
key_ids = attr.ib(init=False, type=List[str])
|
key_ids = attr.ib(type=List[str])
|
||||||
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
|
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
@staticmethod
|
||||||
self.key_ids = signature_ids(self.json_object, self.server_name)
|
def from_json_object(
|
||||||
|
server_name: str,
|
||||||
|
json_object: JsonDict,
|
||||||
|
minimum_valid_until_ms: int,
|
||||||
|
request_name: str,
|
||||||
|
):
|
||||||
|
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
|
||||||
|
object for the given server.
|
||||||
|
"""
|
||||||
|
key_ids = signature_ids(json_object, server_name)
|
||||||
|
return VerifyJsonRequest(
|
||||||
|
server_name,
|
||||||
|
lambda: json_object,
|
||||||
|
minimum_valid_until_ms,
|
||||||
|
request_name=request_name,
|
||||||
|
key_ids=key_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_event(
|
||||||
|
server_name: str,
|
||||||
|
event: EventBase,
|
||||||
|
minimum_valid_until_ms: int,
|
||||||
|
):
|
||||||
|
"""Create a VerifyJsonRequest to verify all signatures on an event
|
||||||
|
object for the given server.
|
||||||
|
"""
|
||||||
|
key_ids = list(event.signatures.get(server_name, []))
|
||||||
|
return VerifyJsonRequest(
|
||||||
|
server_name,
|
||||||
|
# We defer creating the redacted json object, as it uses a lot more
|
||||||
|
# memory than the Event object itself.
|
||||||
|
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
|
||||||
|
minimum_valid_until_ms,
|
||||||
|
request_name=event.event_id,
|
||||||
|
key_ids=key_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
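The callback indirection described in the docstring is the whole memory win: no pruned or copied JSON is built until a signature is actually about to be checked. A stripped-down sketch of the same pattern (names shortened; this is an illustration, not the real class):

from typing import Any, Callable, Dict

class LazyVerifyRequest:
    """Holds a thunk instead of the (potentially large) JSON body."""

    def __init__(self, server_name: str, get_json: Callable[[], Dict[str, Any]]):
        self.server_name = server_name
        self._get_json = get_json

    def json_to_verify(self) -> Dict[str, Any]:
        # Only now is the redacted copy materialised.
        return self._get_json()

# For an event, the expensive prune/redact step is deferred, e.g.:
# req = LazyVerifyRequest(name, lambda: prune_event_dict(event.room_version, event.get_pdu_json()))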
|
||||||
class KeyLookupError(ValueError):
|
class KeyLookupError(ValueError):
|
||||||
|
@ -147,8 +189,13 @@ class Keyring:
|
||||||
Deferred[None]: completes if the the object was correctly signed, otherwise
|
Deferred[None]: completes if the the object was correctly signed, otherwise
|
||||||
errbacks with an error
|
errbacks with an error
|
||||||
"""
|
"""
|
||||||
req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
|
request = VerifyJsonRequest.from_json_object(
|
||||||
requests = (req,)
|
server_name,
|
||||||
|
json_object,
|
||||||
|
validity_time,
|
||||||
|
request_name,
|
||||||
|
)
|
||||||
|
requests = (request,)
|
||||||
return make_deferred_yieldable(self._verify_objects(requests)[0])
|
return make_deferred_yieldable(self._verify_objects(requests)[0])
|
||||||
|
|
||||||
def verify_json_objects_for_server(
|
def verify_json_objects_for_server(
|
||||||
|
@ -175,10 +222,41 @@ class Keyring:
|
||||||
logcontext.
|
logcontext.
|
||||||
"""
|
"""
|
||||||
return self._verify_objects(
|
return self._verify_objects(
|
||||||
VerifyJsonRequest(server_name, json_object, validity_time, request_name)
|
VerifyJsonRequest.from_json_object(
|
||||||
|
server_name, json_object, validity_time, request_name
|
||||||
|
)
|
||||||
for server_name, json_object, validity_time, request_name in server_and_json
|
for server_name, json_object, validity_time, request_name in server_and_json
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def verify_events_for_server(
|
||||||
|
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
|
||||||
|
) -> List[defer.Deferred]:
|
||||||
|
"""Bulk verification of signatures on events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_and_events:
|
||||||
|
Iterable of `(server_name, event, validity_time)` tuples.
|
||||||
|
|
||||||
|
`server_name` is which server we are verifying the signature for
|
||||||
|
on the event.
|
||||||
|
|
||||||
|
`event` is the event that we'll verify the signatures of for
|
||||||
|
the given `server_name`.
|
||||||
|
|
||||||
|
`validity_time` is a timestamp at which the signing key must be
|
||||||
|
valid.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List<Deferred[None]>: for each input triplet, a deferred indicating success
|
||||||
|
or failure to verify each event's signature for the given
|
||||||
|
server_name. The deferreds run their callbacks in the sentinel
|
||||||
|
logcontext.
|
||||||
|
"""
|
||||||
|
return self._verify_objects(
|
||||||
|
VerifyJsonRequest.from_event(server_name, event, validity_time)
|
||||||
|
for server_name, event, validity_time in server_and_events
|
||||||
|
)
|
||||||
|
|
||||||
def _verify_objects(
|
def _verify_objects(
|
||||||
self, verify_requests: Iterable[VerifyJsonRequest]
|
self, verify_requests: Iterable[VerifyJsonRequest]
|
||||||
) -> List[defer.Deferred]:
|
) -> List[defer.Deferred]:
|
||||||
|
@ -892,7 +970,7 @@ async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
_, key_id, verify_key = await verify_request.key_ready
|
_, key_id, verify_key = await verify_request.key_ready
|
||||||
|
|
||||||
json_object = verify_request.json_object
|
json_object = verify_request.get_json_object()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
verify_signed_json(json_object, server_name, verify_key)
|
verify_signed_json(json_object, server_name, verify_key)
|
||||||
|
|
|
@ -137,11 +137,7 @@ class FederationBase:
|
||||||
return deferreds
|
return deferreds
|
||||||
|
|
||||||
|
|
||||||
class PduToCheckSig(
|
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
|
||||||
namedtuple(
|
|
||||||
"PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
|
|
||||||
)
|
|
||||||
):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
|
||||||
pdus_to_check = [
|
pdus_to_check = [
|
||||||
PduToCheckSig(
|
PduToCheckSig(
|
||||||
pdu=p,
|
pdu=p,
|
||||||
redacted_pdu_json=prune_event(p).get_pdu_json(),
|
|
||||||
sender_domain=get_domain_from_id(p.sender),
|
sender_domain=get_domain_from_id(p.sender),
|
||||||
deferreds=[],
|
deferreds=[],
|
||||||
)
|
)
|
||||||
|
@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
|
||||||
# (except if its a 3pid invite, in which case it may be sent by any server)
|
# (except if its a 3pid invite, in which case it may be sent by any server)
|
||||||
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
|
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
|
||||||
|
|
||||||
more_deferreds = keyring.verify_json_objects_for_server(
|
more_deferreds = keyring.verify_events_for_server(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
p.sender_domain,
|
p.sender_domain,
|
||||||
p.redacted_pdu_json,
|
p.pdu,
|
||||||
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
||||||
p.pdu.event_id,
|
|
||||||
)
|
)
|
||||||
for p in pdus_to_check_sender
|
for p in pdus_to_check_sender
|
||||||
]
|
]
|
||||||
|
@ -230,13 +224,12 @@ def _check_sigs_on_pdus(
|
||||||
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
|
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
|
||||||
]
|
]
|
||||||
|
|
||||||
more_deferreds = keyring.verify_json_objects_for_server(
|
more_deferreds = keyring.verify_events_for_server(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
get_domain_from_id(p.pdu.event_id),
|
get_domain_from_id(p.pdu.event_id),
|
||||||
p.redacted_pdu_json,
|
p.pdu,
|
||||||
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
||||||
p.pdu.event_id,
|
|
||||||
)
|
)
|
||||||
for p in pdus_to_check_event_id
|
for p in pdus_to_check_event_id
|
||||||
]
|
]
|
||||||
|
|
|
@ -55,6 +55,7 @@ from synapse.api.room_versions import (
|
||||||
)
|
)
|
||||||
from synapse.events import EventBase, builder
|
from synapse.events import EventBase, builder
|
||||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||||
|
from synapse.federation.transport.client import SendJoinResponse
|
||||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.logging.utils import log_function
|
from synapse.logging.utils import log_function
|
||||||
from synapse.types import JsonDict, get_domain_from_id
|
from synapse.types import JsonDict, get_domain_from_id
|
||||||
|
@ -665,19 +666,10 @@ class FederationClient(FederationBase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def send_request(destination) -> Dict[str, Any]:
|
async def send_request(destination) -> Dict[str, Any]:
|
||||||
content = await self._do_send_join(destination, pdu)
|
response = await self._do_send_join(room_version, destination, pdu)
|
||||||
|
|
||||||
logger.debug("Got content: %s", content)
|
state = response.state
|
||||||
|
auth_chain = response.auth_events
|
||||||
state = [
|
|
||||||
event_from_pdu_json(p, room_version, outlier=True)
|
|
||||||
for p in content.get("state", [])
|
|
||||||
]
|
|
||||||
|
|
||||||
auth_chain = [
|
|
||||||
event_from_pdu_json(p, room_version, outlier=True)
|
|
||||||
for p in content.get("auth_chain", [])
|
|
||||||
]
|
|
||||||
|
|
||||||
pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
|
pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
|
||||||
|
|
||||||
|
@ -752,11 +744,14 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
return await self._try_destination_list("send_join", destinations, send_request)
|
return await self._try_destination_list("send_join", destinations, send_request)
|
||||||
|
|
||||||
async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
|
async def _do_send_join(
|
||||||
|
self, room_version: RoomVersion, destination: str, pdu: EventBase
|
||||||
|
) -> SendJoinResponse:
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.transport_layer.send_join_v2(
|
return await self.transport_layer.send_join_v2(
|
||||||
|
room_version=room_version,
|
||||||
destination=destination,
|
destination=destination,
|
||||||
room_id=pdu.room_id,
|
room_id=pdu.room_id,
|
||||||
event_id=pdu.event_id,
|
event_id=pdu.event_id,
|
||||||
|
@ -771,17 +766,14 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
|
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
|
||||||
|
|
||||||
resp = await self.transport_layer.send_join_v1(
|
return await self.transport_layer.send_join_v1(
|
||||||
|
room_version=room_version,
|
||||||
destination=destination,
|
destination=destination,
|
||||||
room_id=pdu.room_id,
|
room_id=pdu.room_id,
|
||||||
event_id=pdu.event_id,
|
event_id=pdu.event_id,
|
||||||
content=pdu.get_pdu_json(time_now),
|
content=pdu.get_pdu_json(time_now),
|
||||||
)
|
)
|
||||||
|
|
||||||
# We expect the v1 API to respond with [200, content], so we only return the
|
|
||||||
# content.
|
|
||||||
return resp[1]
|
|
||||||
|
|
||||||
async def send_invite(
|
async def send_invite(
|
||||||
self,
|
self,
|
||||||
destination: str,
|
destination: str,
|
||||||
|
|
|
@ -17,13 +17,19 @@ import logging
|
||||||
import urllib
|
import urllib
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import attr
|
||||||
|
import ijson
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
|
from synapse.api.room_versions import RoomVersion
|
||||||
from synapse.api.urls import (
|
from synapse.api.urls import (
|
||||||
FEDERATION_UNSTABLE_PREFIX,
|
FEDERATION_UNSTABLE_PREFIX,
|
||||||
FEDERATION_V1_PREFIX,
|
FEDERATION_V1_PREFIX,
|
||||||
FEDERATION_V2_PREFIX,
|
FEDERATION_V2_PREFIX,
|
||||||
)
|
)
|
||||||
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
|
from synapse.http.matrixfederationclient import ByteParser
|
||||||
from synapse.logging.utils import log_function
|
from synapse.logging.utils import log_function
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
@ -240,21 +246,36 @@ class TransportLayerClient:
|
||||||
return content
|
return content
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
async def send_join_v1(self, destination, room_id, event_id, content):
|
async def send_join_v1(
|
||||||
|
self,
|
||||||
|
room_version,
|
||||||
|
destination,
|
||||||
|
room_id,
|
||||||
|
event_id,
|
||||||
|
content,
|
||||||
|
) -> "SendJoinResponse":
|
||||||
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
|
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
|
||||||
|
|
||||||
response = await self.client.put_json(
|
response = await self.client.put_json(
|
||||||
destination=destination, path=path, data=content
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=content,
|
||||||
|
parser=SendJoinParser(room_version, v1_api=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
async def send_join_v2(self, destination, room_id, event_id, content):
|
async def send_join_v2(
|
||||||
|
self, room_version, destination, room_id, event_id, content
|
||||||
|
) -> "SendJoinResponse":
|
||||||
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
|
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
|
||||||
|
|
||||||
response = await self.client.put_json(
|
response = await self.client.put_json(
|
||||||
destination=destination, path=path, data=content
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=content,
|
||||||
|
parser=SendJoinParser(room_version, v1_api=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -1053,3 +1074,59 @@ def _create_v2_path(path, *args):
|
||||||
str
|
str
|
||||||
"""
|
"""
|
||||||
return _create_path(FEDERATION_V2_PREFIX, path, *args)
|
return _create_path(FEDERATION_V2_PREFIX, path, *args)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
|
class SendJoinResponse:
|
||||||
|
"""The parsed response of a `/send_join` request."""
|
||||||
|
|
||||||
|
auth_events: List[EventBase]
|
||||||
|
state: List[EventBase]
|
||||||
|
|
||||||
|
|
||||||
|
@ijson.coroutine
|
||||||
|
def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
|
||||||
|
"""Helper function for use with `ijson.items_coro` to parse an array of
|
||||||
|
events and add them to the given list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
obj = yield
|
||||||
|
event = make_event_from_dict(obj, room_version)
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
|
||||||
|
class SendJoinParser(ByteParser[SendJoinResponse]):
|
||||||
|
"""A parser for the response to `/send_join` requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_version: The version of the room.
|
||||||
|
v1_api: Whether the response is in the v1 format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
CONTENT_TYPE = "application/json"
|
||||||
|
|
||||||
|
def __init__(self, room_version: RoomVersion, v1_api: bool):
|
||||||
|
self._response = SendJoinResponse([], [])
|
||||||
|
|
||||||
|
# The V1 API has the shape of `[200, {...}]`, which we handle by
|
||||||
|
# prefixing with `item.*`.
|
||||||
|
prefix = "item." if v1_api else ""
|
||||||
|
|
||||||
|
self._coro_state = ijson.items_coro(
|
||||||
|
_event_list_parser(room_version, self._response.state),
|
||||||
|
prefix + "state.item",
|
||||||
|
)
|
||||||
|
self._coro_auth = ijson.items_coro(
|
||||||
|
_event_list_parser(room_version, self._response.auth_events),
|
||||||
|
prefix + "auth_chain.item",
|
||||||
|
)
|
||||||
|
|
||||||
|
def write(self, data: bytes) -> int:
|
||||||
|
self._coro_state.send(data)
|
||||||
|
self._coro_auth.send(data)
|
||||||
|
|
||||||
|
return len(data)
|
||||||
|
|
||||||
|
def finish(self) -> SendJoinResponse:
|
||||||
|
return self._response
|
||||||
|
|
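The parser above leans on ijson's push interface: the response body is fed to it chunk by chunk and each item matching the prefix is handed to a coroutine, so the full /send_join body never has to sit in memory as one parsed blob. A self-contained sketch of that interface (toy JSON, not a real federation response):

import ijson

@ijson.coroutine
def collect(out):
    # Receives every object matching the prefix passed to items_coro.
    while True:
        out.append((yield))

state = []
parser = ijson.items_coro(collect(state), "state.item")
for chunk in (b'{"state": [{"type": "m.room.create"}, ', b'{"type": "m.room.member"}]}'):
    parser.send(chunk)
parser.close()
print(state)  # [{'type': 'm.room.create'}, {'type': 'm.room.member'}]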
|
@ -160,7 +160,7 @@ class Authenticator:
|
||||||
# If we get a valid signed request from the other side, it's probably
|
# If we get a valid signed request from the other side, it's probably
|
||||||
# alive
|
# alive
|
||||||
retry_timings = await self.store.get_destination_retry_timings(origin)
|
retry_timings = await self.store.get_destination_retry_timings(origin)
|
||||||
if retry_timings and retry_timings["retry_last_ts"]:
|
if retry_timings and retry_timings.retry_last_ts:
|
||||||
run_in_background(self._reset_retry_timings, origin)
|
run_in_background(self._reset_retry_timings, origin)
|
||||||
|
|
||||||
return origin
|
return origin
|
||||||
|
@ -1428,7 +1428,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, await self.handler.federation_space_summary(
|
return 200, await self.handler.federation_space_summary(
|
||||||
room_id, suggested_only, max_rooms_per_space, exclude_rooms
|
origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,12 +15,9 @@
|
||||||
import email.mime.multipart
|
import email.mime.multipart
|
||||||
import email.utils
|
import email.utils
|
||||||
import logging
|
import logging
|
||||||
from email.mime.multipart import MIMEMultipart
|
|
||||||
from email.mime.text import MIMEText
|
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import StoreError, SynapseError
|
from synapse.api.errors import StoreError, SynapseError
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
|
@ -36,9 +33,11 @@ class AccountValidityHandler:
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.sendmail = self.hs.get_sendmail()
|
self.send_email_handler = self.hs.get_send_email_handler()
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
|
|
||||||
|
self._app_name = self.hs.config.email_app_name
|
||||||
|
|
||||||
self._account_validity_enabled = (
|
self._account_validity_enabled = (
|
||||||
hs.config.account_validity.account_validity_enabled
|
hs.config.account_validity.account_validity_enabled
|
||||||
)
|
)
|
||||||
|
@ -63,23 +62,10 @@ class AccountValidityHandler:
|
||||||
self._template_text = (
|
self._template_text = (
|
||||||
hs.config.account_validity.account_validity_template_text
|
hs.config.account_validity.account_validity_template_text
|
||||||
)
|
)
|
||||||
account_validity_renew_email_subject = (
|
self._renew_email_subject = (
|
||||||
hs.config.account_validity.account_validity_renew_email_subject
|
hs.config.account_validity.account_validity_renew_email_subject
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
app_name = hs.config.email_app_name
|
|
||||||
|
|
||||||
self._subject = account_validity_renew_email_subject % {"app": app_name}
|
|
||||||
|
|
||||||
self._from_string = hs.config.email_notif_from % {"app": app_name}
|
|
||||||
except Exception:
|
|
||||||
# If substitution failed, fall back to the bare strings.
|
|
||||||
self._subject = account_validity_renew_email_subject
|
|
||||||
self._from_string = hs.config.email_notif_from
|
|
||||||
|
|
||||||
self._raw_from = email.utils.parseaddr(self._from_string)[1]
|
|
||||||
|
|
||||||
# Check the renewal emails to send and send them every 30min.
|
# Check the renewal emails to send and send them every 30min.
|
||||||
if hs.config.run_background_tasks:
|
if hs.config.run_background_tasks:
|
||||||
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
|
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
|
||||||
|
@ -159,38 +145,17 @@ class AccountValidityHandler:
|
||||||
}
|
}
|
||||||
|
|
||||||
html_text = self._template_html.render(**template_vars)
|
html_text = self._template_html.render(**template_vars)
|
||||||
html_part = MIMEText(html_text, "html", "utf8")
|
|
||||||
|
|
||||||
plain_text = self._template_text.render(**template_vars)
|
plain_text = self._template_text.render(**template_vars)
|
||||||
text_part = MIMEText(plain_text, "plain", "utf8")
|
|
||||||
|
|
||||||
for address in addresses:
|
for address in addresses:
|
||||||
raw_to = email.utils.parseaddr(address)[1]
|
raw_to = email.utils.parseaddr(address)[1]
|
||||||
|
|
||||||
multipart_msg = MIMEMultipart("alternative")
|
await self.send_email_handler.send_email(
|
||||||
multipart_msg["Subject"] = self._subject
|
email_address=raw_to,
|
||||||
multipart_msg["From"] = self._from_string
|
subject=self._renew_email_subject,
|
||||||
multipart_msg["To"] = address
|
app_name=self._app_name,
|
||||||
multipart_msg["Date"] = email.utils.formatdate()
|
html=html_text,
|
||||||
multipart_msg["Message-ID"] = email.utils.make_msgid()
|
text=plain_text,
|
||||||
multipart_msg.attach(text_part)
|
|
||||||
multipart_msg.attach(html_part)
|
|
||||||
|
|
||||||
logger.info("Sending renewal email to %s", address)
|
|
||||||
|
|
||||||
await make_deferred_yieldable(
|
|
||||||
self.sendmail(
|
|
||||||
self.hs.config.email_smtp_host,
|
|
||||||
self._raw_from,
|
|
||||||
raw_to,
|
|
||||||
multipart_msg.as_string().encode("utf8"),
|
|
||||||
reactor=self.hs.get_reactor(),
|
|
||||||
port=self.hs.config.email_smtp_port,
|
|
||||||
requireAuthentication=self.hs.config.email_smtp_user is not None,
|
|
||||||
username=self.hs.config.email_smtp_user,
|
|
||||||
password=self.hs.config.email_smtp_pass,
|
|
||||||
requireTransportSecurity=self.hs.config.require_transport_security,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
|
await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
|
||||||
|
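What moved out of this handler is the by-hand MIME assembly; the shared send_email handler now performs the equivalent of the removed lines. For reference, a stdlib-only sketch of that assembly (addresses and subject are placeholders):

import email.utils
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

def build_renewal_email(subject: str, from_addr: str, to_addr: str,
                        plain_text: str, html_text: str) -> MIMEMultipart:
    msg = MIMEMultipart("alternative")
    msg["Subject"] = subject
    msg["From"] = from_addr
    msg["To"] = to_addr
    msg["Date"] = email.utils.formatdate()
    msg["Message-ID"] = email.utils.make_msgid()
    msg.attach(MIMEText(plain_text, "plain", "utf8"))  # fallback part
    msg.attach(MIMEText(html_text, "html", "utf8"))    # preferred part
    return msg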
|
|
@ -11,10 +11,12 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Collection, Optional
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, JoinRules
|
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||||
|
from synapse.api.errors import AuthError
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import RoomVersion
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.types import StateMap
|
from synapse.types import StateMap
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -29,46 +31,104 @@ class EventAuthHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self._store = hs.get_datastore()
|
self._store = hs.get_datastore()
|
||||||
|
|
||||||
async def can_join_without_invite(
|
async def check_restricted_join_rules(
|
||||||
self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
|
self,
|
||||||
) -> bool:
|
state_ids: StateMap[str],
|
||||||
|
room_version: RoomVersion,
|
||||||
|
user_id: str,
|
||||||
|
prev_member_event: Optional[EventBase],
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Check whether a user can join a room without an invite.
|
Check whether a user can join a room without an invite due to restricted join rules.
|
||||||
|
|
||||||
When joining a room with restricted join rules (as defined in MSC3083),
|
When joining a room with restricted join rules (as defined in MSC3083),
|
||||||
the membership of spaces must be checked during join.
|
the membership of spaces must be checked during a room join.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_ids: The state of the room as it currently is.
|
state_ids: The state of the room as it currently is.
|
||||||
room_version: The room version of the room being joined.
|
room_version: The room version of the room being joined.
|
||||||
user_id: The user joining the room.
|
user_id: The user joining the room.
|
||||||
|
prev_member_event: The current membership event for this user.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthError if the user cannot join the room.
|
||||||
|
"""
|
||||||
|
# If the member is invited or currently joined, then nothing to do.
|
||||||
|
if prev_member_event and (
|
||||||
|
prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# This is not a room with a restricted join rule, so we don't need to do the
|
||||||
|
# restricted room specific checks.
|
||||||
|
#
|
||||||
|
# Note: We'll be applying the standard join rule checks later, which will
|
||||||
|
# catch the cases of e.g. trying to join private rooms without an invite.
|
||||||
|
if not await self.has_restricted_join_rules(state_ids, room_version):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get the spaces which allow access to this room and check if the user is
|
||||||
|
# in any of them.
|
||||||
|
allowed_spaces = await self.get_spaces_that_allow_join(state_ids)
|
||||||
|
if not await self.is_user_in_rooms(allowed_spaces, user_id):
|
||||||
|
raise AuthError(
|
||||||
|
403,
|
||||||
|
"You do not belong to any of the required spaces to join this room.",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def has_restricted_join_rules(
|
||||||
|
self, state_ids: StateMap[str], room_version: RoomVersion
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Return if the room has the proper join rules set for access via spaces.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_ids: The state of the room as it currently is.
|
||||||
|
room_version: The room version of the room to query.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the user can join the room, false otherwise.
|
True if the proper room version and join rules are set for restricted access.
|
||||||
"""
|
"""
|
||||||
# This only applies to room versions which support the new join rule.
|
# This only applies to room versions which support the new join rule.
|
||||||
if not room_version.msc3083_join_rules:
|
if not room_version.msc3083_join_rules:
|
||||||
return True
|
return False
|
||||||
|
|
||||||
# If there's no join rule, then it defaults to invite (so this doesn't apply).
|
# If there's no join rule, then it defaults to invite (so this doesn't apply).
|
||||||
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
|
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
|
||||||
if not join_rules_event_id:
|
if not join_rules_event_id:
|
||||||
return True
|
return False
|
||||||
|
|
||||||
|
# If the join rule is not restricted, this doesn't apply.
|
||||||
|
join_rules_event = await self._store.get_event(join_rules_event_id)
|
||||||
|
return join_rules_event.content.get("join_rule") == JoinRules.MSC3083_RESTRICTED
|
||||||
|
|
||||||
|
async def get_spaces_that_allow_join(
|
||||||
|
self, state_ids: StateMap[str]
|
||||||
|
) -> Collection[str]:
|
||||||
|
"""
|
||||||
|
Generate a list of spaces which allow access to a room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_ids: The state of the room as it currently is.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A collection of spaces which provide membership to the room.
|
||||||
|
"""
|
||||||
|
# If there's no join rule, then it defaults to invite (so this doesn't apply).
|
||||||
|
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
|
||||||
|
if not join_rules_event_id:
|
||||||
|
return ()
|
||||||
|
|
||||||
# If the join rule is not restricted, this doesn't apply.
|
# If the join rule is not restricted, this doesn't apply.
|
||||||
join_rules_event = await self._store.get_event(join_rules_event_id)
|
join_rules_event = await self._store.get_event(join_rules_event_id)
|
||||||
if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# If allowed is of the wrong form, then only allow invited users.
|
# If allowed is of the wrong form, then only allow invited users.
|
||||||
allowed_spaces = join_rules_event.content.get("allow", [])
|
allowed_spaces = join_rules_event.content.get("allow", [])
|
||||||
if not isinstance(allowed_spaces, list):
|
if not isinstance(allowed_spaces, list):
|
||||||
return False
|
return ()
|
||||||
|
|
||||||
# Get the list of joined rooms and see if there's an overlap.
|
|
||||||
joined_rooms = await self._store.get_rooms_for_user(user_id)
|
|
||||||
|
|
||||||
# Pull out the other room IDs, invalid data gets filtered.
|
# Pull out the other room IDs, invalid data gets filtered.
|
||||||
|
result = []
|
||||||
for space in allowed_spaces:
|
for space in allowed_spaces:
|
||||||
if not isinstance(space, dict):
|
if not isinstance(space, dict):
|
||||||
continue
|
continue
|
||||||
|
@ -77,10 +137,31 @@ class EventAuthHandler:
|
||||||
if not isinstance(space_id, str):
|
if not isinstance(space_id, str):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# The user was joined to one of the spaces specified, they can join
|
result.append(space_id)
|
||||||
# this room!
|
|
||||||
if space_id in joined_rooms:
|
return result
|
||||||
|
|
||||||
|
async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether a user is a member of any of the provided rooms.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_ids: The rooms to check for membership.
|
||||||
|
user_id: The user to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the user is in any of the rooms, false otherwise.
|
||||||
|
"""
|
||||||
|
if not room_ids:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Get the list of joined rooms and see if there's an overlap.
|
||||||
|
joined_rooms = await self._store.get_rooms_for_user(user_id)
|
||||||
|
|
||||||
|
# Check each room and see if the user is in it.
|
||||||
|
for room_id in room_ids:
|
||||||
|
if room_id in joined_rooms:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# The user was not in any of the required spaces.
|
# The user was not in any of the rooms.
|
||||||
return False
|
return False
|
||||||
|
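For readers following the new API: the three helpers are meant to be composed, which is what check_restricted_join_rules does internally and what the spaces summary code can now reuse. A condensed sketch of that composition (no error handling; handler is assumed to be an EventAuthHandler):

async def user_may_join_restricted_room(handler, state_ids, room_version, user_id) -> bool:
    # Rooms without MSC3083 restricted join rules are handled by the
    # ordinary join-rule checks elsewhere.
    if not await handler.has_restricted_join_rules(state_ids, room_version):
        return True
    allowed_spaces = await handler.get_spaces_that_allow_join(state_ids)
    return await handler.is_user_in_rooms(allowed_spaces, user_id)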
|
|
@@ -1668,27 +1668,16 @@ class FederationHandler(BaseHandler):
         # Check if the user is already in the room or invited to the room.
         user_id = event.state_key
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
-        newly_joined = True
-        user_is_invited = False
+        prev_member_event = None
         if prev_member_event_id:
             prev_member_event = await self.store.get_event(prev_member_event_id)
-            newly_joined = prev_member_event.membership != Membership.JOIN
-            user_is_invited = prev_member_event.membership == Membership.INVITE
 
-        # If the member is not already in the room, and not invited, check if
-        # they should be allowed access via membership in a space.
-        if (
-            newly_joined
-            and not user_is_invited
-            and not await self._event_auth_handler.can_join_without_invite(
+        # Check if the member should be allowed access via membership in a space.
+        await self._event_auth_handler.check_restricted_join_rules(
             prev_state_ids,
             event.room_version,
             user_id,
-            )
-        ):
-            raise AuthError(
-                403,
-                "You do not belong to any of the required spaces to join this room.",
-            )
+            prev_member_event,
         )
 
         # Persist the event.
@@ -222,9 +222,21 @@ class BasePresenceHandler(abc.ABC):
 
     @abc.abstractmethod
     async def set_state(
-        self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
+        self,
+        target_user: UserID,
+        state: JsonDict,
+        ignore_status_msg: bool = False,
+        force_notify: bool = False,
     ) -> None:
-        """Set the presence state of the user. """
+        """Set the presence state of the user.
+
+        Args:
+            target_user: The ID of the user to set the presence state of.
+            state: The presence state as a JSON dictionary.
+            ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
+                If False, the user's current status will be updated.
+            force_notify: Whether to force notification of the update to clients.
+        """
 
     @abc.abstractmethod
     async def bump_presence_active_time(self, user: UserID):
@@ -296,6 +308,51 @@ class BasePresenceHandler(abc.ABC):
         for destinations, states in hosts_and_states:
             self._federation.send_presence_to_destinations(states, destinations)
 
+    async def send_full_presence_to_users(self, user_ids: Collection[str]):
+        """
+        Adds to the list of users who should receive a full snapshot of presence
+        upon their next sync. Note that this only works for local users.
+
+        Then, grabs the current presence state for a given set of users and adds it
+        to the top of the presence stream.
+
+        Args:
+            user_ids: The IDs of the local users to send full presence to.
+        """
+        # Retrieve one of the users from the given set
+        if not user_ids:
+            raise Exception(
+                "send_full_presence_to_users must be called with at least one user"
+            )
+        user_id = next(iter(user_ids))
+
+        # Mark all users as receiving full presence on their next sync
+        await self.store.add_users_to_send_full_presence_to(user_ids)
+
+        # Add a new entry to the presence stream. Since we use stream tokens to determine whether a
+        # local user should receive a full snapshot of presence when they sync, we need to bump the
+        # presence stream so that subsequent syncs with no presence activity in between won't result
+        # in the client receiving multiple full snapshots of presence.
+        #
+        # If we bump the stream ID, then the user will get a higher stream token next sync, and thus
+        # correctly won't receive a second snapshot.
+
+        # Get the current presence state for one of the users (defaults to offline if not found)
+        current_presence_state = await self.get_state(UserID.from_string(user_id))
+
+        # Convert the UserPresenceState object into a serializable dict
+        state = {
+            "presence": current_presence_state.state,
+            "status_message": current_presence_state.status_msg,
+        }
+
+        # Copy the presence state to the tip of the presence stream.
+
+        # We set force_notify=True here so that this presence update is guaranteed to
+        # increment the presence stream ID (which resending the current user's presence
+        # otherwise would not do).
+        await self.set_state(UserID.from_string(user_id), state, force_notify=True)
+
 
 class _NullContextManager(ContextManager[None]):
     """A context manager which does nothing."""
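A usage sketch of the new helper, with hypothetical user IDs and `presence_handler` standing for `hs.get_presence_handler()`:

# Sketch only: force two local users to receive a full presence snapshot
# on their next sync.
await presence_handler.send_full_presence_to_users(
    {"@alice:example.org", "@bob:example.org"}
)
# Internally this marks the users in the database and bumps the presence stream
# via set_state(..., force_notify=True) so the next sync sees a newer token.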
@@ -480,8 +537,17 @@ class WorkerPresenceHandler(BasePresenceHandler):
         target_user: UserID,
         state: JsonDict,
         ignore_status_msg: bool = False,
+        force_notify: bool = False,
     ) -> None:
-        """Set the presence state of the user."""
+        """Set the presence state of the user.
+
+        Args:
+            target_user: The ID of the user to set the presence state of.
+            state: The presence state as a JSON dictionary.
+            ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
+                If False, the user's current status will be updated.
+            force_notify: Whether to force notification of the update to clients.
+        """
         presence = state["presence"]
 
         valid_presence = (
@@ -508,6 +574,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
             user_id=user_id,
             state=state,
             ignore_status_msg=ignore_status_msg,
+            force_notify=force_notify,
         )
 
     async def bump_presence_active_time(self, user: UserID) -> None:
@@ -677,13 +744,19 @@ class PresenceHandler(BasePresenceHandler):
             [self.user_to_current_state[user_id] for user_id in unpersisted]
         )
 
-    async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None:
+    async def _update_states(
+        self, new_states: Iterable[UserPresenceState], force_notify: bool = False
+    ) -> None:
         """Updates presence of users. Sets the appropriate timeouts. Pokes
         the notifier and federation if and only if the changed presence state
         should be sent to clients/servers.
 
         Args:
             new_states: The new user presence state updates to process.
+            force_notify: Whether to force notifying clients of this presence state update,
+                even if it doesn't change the state of a user's presence (e.g online -> online).
+                This is currently used to bump the max presence stream ID without changing any
+                user's presence (see PresenceHandler.add_users_to_send_full_presence_to).
         """
         now = self.clock.time_msec()
 
@@ -720,6 +793,9 @@ class PresenceHandler(BasePresenceHandler):
                 now=now,
             )
 
+            if force_notify:
+                should_notify = True
+
             self.user_to_current_state[user_id] = new_state
 
             if should_notify:
@@ -1058,9 +1134,21 @@ class PresenceHandler(BasePresenceHandler):
         await self._update_states(updates)
 
     async def set_state(
-        self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
+        self,
+        target_user: UserID,
+        state: JsonDict,
+        ignore_status_msg: bool = False,
+        force_notify: bool = False,
     ) -> None:
-        """Set the presence state of the user."""
+        """Set the presence state of the user.
+
+        Args:
+            target_user: The ID of the user to set the presence state of.
+            state: The presence state as a JSON dictionary.
+            ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
+                If False, the user's current status will be updated.
+            force_notify: Whether to force notification of the update to clients.
+        """
         status_msg = state.get("status_msg", None)
         presence = state["presence"]
 
@@ -1091,7 +1179,9 @@ class PresenceHandler(BasePresenceHandler):
         ):
             new_fields["last_active_ts"] = self.clock.time_msec()
 
-        await self._update_states([prev_state.copy_and_replace(**new_fields)])
+        await self._update_states(
+            [prev_state.copy_and_replace(**new_fields)], force_notify=force_notify
+        )
 
     async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
         """Returns whether a user can see another user's presence."""
@@ -1389,11 +1479,10 @@ class PresenceEventSource:
         #
         # Presence -> Notifier -> PresenceEventSource -> Presence
         #
-        # Same with get_module_api, get_presence_router
+        # Same with get_presence_router:
         #
         # AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler
         self.get_presence_handler = hs.get_presence_handler
-        self.get_module_api = hs.get_module_api
         self.get_presence_router = hs.get_presence_router
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
@@ -1424,16 +1513,21 @@ class PresenceEventSource:
         stream_change_cache = self.store.presence_stream_cache
 
         with Measure(self.clock, "presence.get_new_events"):
-            if user_id in self.get_module_api()._send_full_presence_to_local_users:
+            if from_key is not None:
+                from_key = int(from_key)
+
+            # Check if this user should receive all current, online user presence. We only
+            # bother to do this if from_key is set, as otherwise the user will receive all
+            # user presence anyways.
+            if await self.store.should_user_receive_full_presence_with_token(
+                user_id, from_key
+            ):
                 # This user has been specified by a module to receive all current, online
                 # user presence. Removing from_key and setting include_offline to false
                 # will do effectively this.
                 from_key = None
                 include_offline = False
 
-            if from_key is not None:
-                from_key = int(from_key)
-
             max_token = self.store.get_current_presence_token()
             if from_key == max_token:
                 # This is necessary as due to the way stream ID generators work
@@ -1467,12 +1561,6 @@ class PresenceEventSource:
                     user_id, include_offline, from_key
                 )
 
-                # Remove the user from the list of users to receive all presence
-                if user_id in self.get_module_api()._send_full_presence_to_local_users:
-                    self.get_module_api()._send_full_presence_to_local_users.remove(
-                        user_id
-                    )
-
                 return presence_updates, max_token
 
             # Make mypy happy. users_interested_in should now be a set
@@ -1522,10 +1610,6 @@ class PresenceEventSource:
             )
             presence_updates = list(users_to_state.values())
 
-            # Remove the user from the list of users to receive all presence
-            if user_id in self.get_module_api()._send_full_presence_to_local_users:
-                self.get_module_api()._send_full_presence_to_local_users.remove(user_id)
-
         if not include_offline:
             # Filter out offline presence states
             presence_updates = self._filter_offline_presence_state(presence_updates)
@@ -260,24 +260,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         if event.membership == Membership.JOIN:
             newly_joined = True
-            user_is_invited = False
+            prev_member_event = None
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 newly_joined = prev_member_event.membership != Membership.JOIN
-                user_is_invited = prev_member_event.membership == Membership.INVITE
 
-            # If the member is not already in the room and is not accepting an invite,
-            # check if they should be allowed access via membership in a space.
-            if (
-                newly_joined
-                and not user_is_invited
-                and not await self.event_auth_handler.can_join_without_invite(
-                    prev_state_ids, event.room_version, user_id
-                )
-            ):
-                raise AuthError(
-                    403,
-                    "You do not belong to any of the required spaces to join this room.",
-                )
+            # Check if the member should be allowed access via membership in a space.
+            await self.event_auth_handler.check_restricted_join_rules(
+                prev_state_ids, event.room_version, user_id, prev_member_event
+            )
 
             # Only rate-limit if the user actually joined the room, otherwise we'll end
 98  synapse/handlers/send_email.py  (Normal file)
@@ -0,0 +1,98 @@
+# Copyright 2021 The Matrix.org C.I.C. Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import email.utils
+import logging
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
+from typing import TYPE_CHECKING
+
+from synapse.logging.context import make_deferred_yieldable
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SendEmailHandler:
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+
+        self._sendmail = hs.get_sendmail()
+        self._reactor = hs.get_reactor()
+
+        self._from = hs.config.email.email_notif_from
+        self._smtp_host = hs.config.email.email_smtp_host
+        self._smtp_port = hs.config.email.email_smtp_port
+        self._smtp_user = hs.config.email.email_smtp_user
+        self._smtp_pass = hs.config.email.email_smtp_pass
+        self._require_transport_security = hs.config.email.require_transport_security
+
+    async def send_email(
+        self,
+        email_address: str,
+        subject: str,
+        app_name: str,
+        html: str,
+        text: str,
+    ) -> None:
+        """Send a multipart email with the given information.
+
+        Args:
+            email_address: The address to send the email to.
+            subject: The email's subject.
+            app_name: The app name to include in the From header.
+            html: The HTML content to include in the email.
+            text: The plain text content to include in the email.
+        """
+        try:
+            from_string = self._from % {"app": app_name}
+        except (KeyError, TypeError):
+            from_string = self._from
+
+        raw_from = email.utils.parseaddr(from_string)[1]
+        raw_to = email.utils.parseaddr(email_address)[1]
+
+        if raw_to == "":
+            raise RuntimeError("Invalid 'to' address")
+
+        html_part = MIMEText(html, "html", "utf8")
+        text_part = MIMEText(text, "plain", "utf8")
+
+        multipart_msg = MIMEMultipart("alternative")
+        multipart_msg["Subject"] = subject
+        multipart_msg["From"] = from_string
+        multipart_msg["To"] = email_address
+        multipart_msg["Date"] = email.utils.formatdate()
+        multipart_msg["Message-ID"] = email.utils.make_msgid()
+        multipart_msg.attach(text_part)
+        multipart_msg.attach(html_part)
+
+        logger.info("Sending email to %s" % email_address)
+
+        await make_deferred_yieldable(
+            self._sendmail(
+                self._smtp_host,
+                raw_from,
+                raw_to,
+                multipart_msg.as_string().encode("utf8"),
+                reactor=self._reactor,
+                port=self._smtp_port,
+                requireAuthentication=self._smtp_user is not None,
+                username=self._smtp_user,
+                password=self._smtp_pass,
+                requireTransportSecurity=self._require_transport_security,
+            )
+        )
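A usage sketch of the new handler, assuming a configured SMTP section and a hypothetical recipient; callers obtain it via `hs.get_send_email_handler()`, as the mailer change later in this diff does:

# Sketch only: the address, subject and body are hypothetical.
send_email_handler = hs.get_send_email_handler()
await send_email_handler.send_email(
    email_address="user@example.com",
    subject="[Matrix] You have new messages",
    app_name="Matrix",
    html="<p>You have unread messages.</p>",
    text="You have unread messages.",
)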
@@ -16,11 +16,16 @@ import itertools
 import logging
 import re
 from collections import deque
-from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple, cast
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple
 
 import attr
 
-from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    HistoryVisibility,
+    Membership,
+)
 from synapse.api.errors import AuthError
 from synapse.events import EventBase
 from synapse.events.utils import format_event_for_client_v2
@@ -32,7 +37,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 # number of rooms to return. We'll stop once we hit this limit.
-# TODO: allow clients to reduce this with a request param.
 MAX_ROOMS = 50
 
 # max number of events to return per room.
@@ -46,8 +50,7 @@ class SpaceSummaryHandler:
     def __init__(self, hs: "HomeServer"):
         self._clock = hs.get_clock()
         self._auth = hs.get_auth()
-        self._room_list_handler = hs.get_room_list_handler()
-        self._state_handler = hs.get_state_handler()
+        self._event_auth_handler = hs.get_event_auth_handler()
         self._store = hs.get_datastore()
         self._event_serializer = hs.get_event_client_serializer()
         self._server_name = hs.hostname
@@ -112,28 +115,88 @@ class SpaceSummaryHandler:
             max_children = max_rooms_per_space if processed_rooms else None
 
             if is_in_room:
-                rooms, events = await self._summarize_local_room(
-                    requester, room_id, suggested_only, max_children
+                room, events = await self._summarize_local_room(
+                    requester, None, room_id, suggested_only, max_children
                 )
+
+                logger.debug(
+                    "Query of local room %s returned events %s",
+                    room_id,
+                    ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events],
+                )
+
+                if room:
+                    rooms_result.append(room)
             else:
-                rooms, events = await self._summarize_remote_room(
+                fed_rooms, fed_events = await self._summarize_remote_room(
                     queue_entry,
                     suggested_only,
                     max_children,
                     exclude_rooms=processed_rooms,
                 )
 
-                logger.debug(
-                    "Query of %s returned rooms %s, events %s",
-                    queue_entry.room_id,
-                    [room.get("room_id") for room in rooms],
-                    ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events],
-                )
-
-            rooms_result.extend(rooms)
-
-            # any rooms returned don't need visiting again
-            processed_rooms.update(cast(str, room.get("room_id")) for room in rooms)
+                # The results over federation might include rooms that the we,
+                # as the requesting server, are allowed to see, but the requesting
+                # user is not permitted see.
+                #
+                # Filter the returned results to only what is accessible to the user.
+                room_ids = set()
+                events = []
+                for room in fed_rooms:
+                    fed_room_id = room.get("room_id")
+                    if not fed_room_id or not isinstance(fed_room_id, str):
+                        continue
+
+                    # The room should only be included in the summary if:
+                    #     a. the user is in the room;
+                    #     b. the room is world readable; or
+                    #     c. the user is in a space that has been granted access to
+                    #        the room.
+                    #
+                    # Note that we know the user is not in the root room (which is
+                    # why the remote call was made in the first place), but the user
+                    # could be in one of the children rooms and we just didn't know
+                    # about the link.
+                    include_room = room.get("world_readable") is True
+
+                    # Check if the user is a member of any of the allowed spaces
+                    # from the response.
+                    allowed_spaces = room.get("allowed_spaces")
+                    if (
+                        not include_room
+                        and allowed_spaces
+                        and isinstance(allowed_spaces, list)
+                    ):
+                        include_room = await self._event_auth_handler.is_user_in_rooms(
+                            allowed_spaces, requester
+                        )
+
+                    # Finally, if this isn't the requested room, check ourselves
+                    # if we can access the room.
+                    if not include_room and fed_room_id != queue_entry.room_id:
+                        include_room = await self._is_room_accessible(
+                            fed_room_id, requester, None
+                        )
+
+                    # The user can see the room, include it!
+                    if include_room:
+                        rooms_result.append(room)
+                        room_ids.add(fed_room_id)
+
+                    # All rooms returned don't need visiting again (even if the user
+                    # didn't have access to them).
+                    processed_rooms.add(fed_room_id)
+
+                for event in fed_events:
+                    if event.get("room_id") in room_ids:
+                        events.append(event)
+
+                logger.debug(
+                    "Query of %s returned rooms %s, events %s",
+                    room_id,
+                    [room.get("room_id") for room in fed_rooms],
+                    ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in fed_events],
+                )
 
             # the room we queried may or may not have been returned, but don't process
             # it again, anyway.
@@ -159,10 +222,16 @@ class SpaceSummaryHandler:
             )
             processed_events.add(ev_key)
 
+        # Before returning to the client, remove the allowed_spaces key for any
+        # rooms.
+        for room in rooms_result:
+            room.pop("allowed_spaces", None)
+
         return {"rooms": rooms_result, "events": events_result}
 
     async def federation_space_summary(
         self,
+        origin: str,
         room_id: str,
         suggested_only: bool,
         max_rooms_per_space: Optional[int],
@@ -172,6 +241,8 @@ class SpaceSummaryHandler:
         Implementation of the space summary Federation API
 
         Args:
+            origin: The server requesting the spaces summary.
+
             room_id: room id to start the summary at
 
             suggested_only: whether we should only return children with the "suggested"
@@ -206,13 +277,14 @@ class SpaceSummaryHandler:
 
             logger.debug("Processing room %s", room_id)
 
-            rooms, events = await self._summarize_local_room(
-                None, room_id, suggested_only, max_rooms_per_space
+            room, events = await self._summarize_local_room(
+                None, origin, room_id, suggested_only, max_rooms_per_space
             )
 
             processed_rooms.add(room_id)
 
-            rooms_result.extend(rooms)
+            if room:
+                rooms_result.append(room)
             events_result.extend(events)
 
             # add any children to the queue
@@ -223,19 +295,27 @@ class SpaceSummaryHandler:
     async def _summarize_local_room(
         self,
         requester: Optional[str],
+        origin: Optional[str],
         room_id: str,
         suggested_only: bool,
         max_children: Optional[int],
-    ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]:
+    ) -> Tuple[Optional[JsonDict], Sequence[JsonDict]]:
         """
         Generate a room entry and a list of event entries for a given room.
 
         Args:
-            requester: The requesting user, or None if this is over federation.
+            requester:
+                The user requesting the summary, if it is a local request. None
+                if this is a federation request.
+            origin:
+                The server requesting the summary, if it is a federation request.
+                None if this is a local request.
             room_id: The room ID to summarize.
             suggested_only: True if only suggested children should be returned.
                 Otherwise, all children are returned.
-            max_children: The maximum number of children to return for this node.
+            max_children:
+                The maximum number of children rooms to include. This is capped
+                to a server-set limit.
 
         Returns:
             A tuple of:
@@ -244,8 +324,8 @@ class SpaceSummaryHandler:
             An iterable of the sorted children events. This may be limited
             to a maximum size or may include all children.
         """
-        if not await self._is_room_accessible(room_id, requester):
-            return (), ()
+        if not await self._is_room_accessible(room_id, requester, origin):
+            return None, ()
 
         room_entry = await self._build_room_entry(room_id)
 
@@ -269,7 +349,7 @@ class SpaceSummaryHandler:
                     event_format=format_event_for_client_v2,
                 )
             )
-        return (room_entry,), events_result
+        return room_entry, events_result
 
     async def _summarize_remote_room(
         self,
@@ -278,6 +358,26 @@ class SpaceSummaryHandler:
         max_children: Optional[int],
         exclude_rooms: Iterable[str],
     ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]:
+        """
+        Request room entries and a list of event entries for a given room by querying a remote server.
+
+        Args:
+            room: The room to summarize.
+            suggested_only: True if only suggested children should be returned.
+                Otherwise, all children are returned.
+            max_children:
+                The maximum number of children rooms to include. This is capped
+                to a server-set limit.
+            exclude_rooms:
+                Rooms IDs which do not need to be summarized.
+
+        Returns:
+            A tuple of:
+                An iterable of rooms.
+
+                An iterable of the sorted children events. This may be limited
+                to a maximum size or may include all children.
+        """
         room_id = room.room_id
         logger.info("Requesting summary for %s via %s", room_id, room.via)
 
@@ -309,27 +409,93 @@ class SpaceSummaryHandler:
             or ev.event_type == EventTypes.SpaceChild
         )
 
-    async def _is_room_accessible(self, room_id: str, requester: Optional[str]) -> bool:
-        # if we have an authenticated requesting user, first check if they are in the
-        # room
+    async def _is_room_accessible(
+        self, room_id: str, requester: Optional[str], origin: Optional[str]
+    ) -> bool:
+        """
+        Calculate whether the room should be shown in the spaces summary.
+
+        It should be included if:
+
+        * The requester is joined or invited to the room.
+        * The requester can join without an invite (per MSC3083).
+        * The origin server has any user that is joined or invited to the room.
+        * The history visibility is set to world readable.
+
+        Args:
+            room_id: The room ID to summarize.
+            requester:
+                The user requesting the summary, if it is a local request. None
+                if this is a federation request.
+            origin:
+                The server requesting the summary, if it is a federation request.
+                None if this is a local request.
+
+        Returns:
+            True if the room should be included in the spaces summary.
+        """
+        state_ids = await self._store.get_current_state_ids(room_id)
+
+        # If there's no state for the room, it isn't known.
+        if not state_ids:
+            logger.info("room %s is unknown, omitting from summary", room_id)
+            return False
+
+        room_version = await self._store.get_room_version(room_id)
+
+        # if we have an authenticated requesting user, first check if they are able to view
+        # stripped state in the room.
         if requester:
-            try:
-                await self._auth.check_user_in_room(room_id, requester)
+            member_event_id = state_ids.get((EventTypes.Member, requester), None)
+
+            # If they're in the room they can see info on it.
+            member_event = None
+            if member_event_id:
+                member_event = await self._store.get_event(member_event_id)
+                if member_event.membership in (Membership.JOIN, Membership.INVITE):
+                    return True
+
+            # Otherwise, check if they should be allowed access via membership in a space.
+            try:
+                await self._event_auth_handler.check_restricted_join_rules(
+                    state_ids, room_version, requester, member_event
+                )
             except AuthError:
+                # The user doesn't have access due to spaces, but might have access
+                # another way. Keep trying.
                 pass
+            else:
+                return True
+
+        # If this is a request over federation, check if the host is in the room or
+        # is in one of the spaces specified via the join rules.
+        elif origin:
+            if await self._auth.check_host_in_room(room_id, origin):
+                return True
+
+            # Alternately, if the host has a user in any of the spaces specified
+            # for access, then the host can see this room (and should do filtering
+            # if the requester cannot see it).
+            if await self._event_auth_handler.has_restricted_join_rules(
+                state_ids, room_version
+            ):
+                allowed_spaces = (
+                    await self._event_auth_handler.get_spaces_that_allow_join(state_ids)
+                )
+                for space_id in allowed_spaces:
+                    if await self._auth.check_host_in_room(space_id, origin):
+                        return True
 
         # otherwise, check if the room is peekable
-        hist_vis_ev = await self._state_handler.get_current_state(
-            room_id, EventTypes.RoomHistoryVisibility, ""
-        )
-        if hist_vis_ev:
+        hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""), None)
+        if hist_vis_event_id:
+            hist_vis_ev = await self._store.get_event(hist_vis_event_id)
             hist_vis = hist_vis_ev.content.get("history_visibility")
             if hist_vis == HistoryVisibility.WORLD_READABLE:
                 return True
 
         logger.info(
-            "room %s is unpeekable and user %s is not a member, omitting from summary",
+            "room %s is unpeekable and user %s is not a member / not allowed to join, omitting from summary",
             room_id,
             requester,
         )
@@ -354,6 +520,15 @@ class SpaceSummaryHandler:
         if not room_type:
            room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE)
 
+        room_version = await self._store.get_room_version(room_id)
+        allowed_spaces = None
+        if await self._event_auth_handler.has_restricted_join_rules(
+            current_state_ids, room_version
+        ):
+            allowed_spaces = await self._event_auth_handler.get_spaces_that_allow_join(
+                current_state_ids
+            )
+
         entry = {
             "room_id": stats["room_id"],
             "name": stats["name"],
@@ -367,6 +542,7 @@ class SpaceSummaryHandler:
             "guest_can_join": stats["guest_access"] == "can_join",
             "creation_ts": create_event.origin_server_ts,
             "room_type": room_type,
+            "allowed_spaces": allowed_spaces,
         }
 
         # Filter out Nones – rather omit the field altogether
@@ -430,8 +606,8 @@ def _is_suggested_child_event(edge_event: EventBase) -> bool:
     return False
 
 
-# Order may only contain characters in the range of \x20 (space) to \x7F (~).
-_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7F]")
+# Order may only contain characters in the range of \x20 (space) to \x7E (~) inclusive.
+_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]")
 
 
 def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]:
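To illustrate the effect of the `_build_room_entry` change: a summary entry exchanged with a peer server now carries an internal `allowed_spaces` key, which the earlier hunk strips before the response is returned to clients. A sketch of such an entry, with hypothetical IDs and only a subset of the fields built above:

# Sketch only: shape of a room entry as seen over federation, before the
# allowed_spaces key is popped for client responses.
entry = {
    "room_id": "!room:example.org",
    "name": "General",
    "world_readable": False,
    "guest_can_join": False,
    "room_type": None,
    "allowed_spaces": ["!space:example.org"],
}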
@@ -814,7 +814,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         if self.deferred.called:
             return
 
-        self.stream.write(data)
+        try:
+            self.stream.write(data)
+        except Exception:
+            self.deferred.errback()
+            return
+
         self.length += len(data)
         # The first time the maximum size is exceeded, error and cancel the
         # connection. dataReceived might be called again if data was received
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import cgi
 import codecs
 import logging
@@ -19,13 +20,24 @@ import sys
 import typing
 import urllib.parse
 from io import BytesIO, StringIO
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import attr
 import treq
 from canonicaljson import encode_canonical_json
 from prometheus_client import Counter
 from signedjson.sign import sign_json
+from typing_extensions import Literal
 
 from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
@@ -48,6 +60,7 @@ from synapse.http.client import (
     BlacklistingAgentWrapper,
     BlacklistingReactorWrapper,
     BodyExceededMaxSize,
+    ByteWriteable,
     encode_query_args,
     read_body_with_max_size,
 )
@@ -88,6 +101,27 @@ _next_id = 1
 QueryArgs = Dict[str, Union[str, List[str]]]
 
 
+T = TypeVar("T")
+
+
+class ByteParser(ByteWriteable, Generic[T], abc.ABC):
+    """A `ByteWriteable` that has an additional `finish` function that returns
+    the parsed data.
+    """
+
+    CONTENT_TYPE = abc.abstractproperty()  # type: str  # type: ignore
+    """The expected content type of the response, e.g. `application/json`. If
+    the content type doesn't match we fail the request.
+    """
+
+    @abc.abstractmethod
+    def finish(self) -> T:
+        """Called when response has finished streaming and the parser should
+        return the final result (or error).
+        """
+        pass
+
+
 @attr.s(slots=True, frozen=True)
 class MatrixFederationRequest:
     method = attr.ib(type=str)
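The new ByteParser abstraction lets a federation response be streamed into something other than a JSON buffer. As a sketch of the `write`/`finish` contract only, a hypothetical parser that simply accumulates the raw body might look like this (not part of the changeset):

# Sketch only: a hypothetical parser that buffers the raw response body.
class RawBytesParser(ByteParser[bytes]):
    CONTENT_TYPE = "application/octet-stream"

    def __init__(self) -> None:
        self._buffer = BytesIO()

    def write(self, data: bytes) -> int:
        return self._buffer.write(data)

    def finish(self) -> bytes:
        return self._buffer.getvalue()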
@@ -148,15 +182,32 @@ class MatrixFederationRequest:
         return self.json
 
 
-async def _handle_json_response(
+class JsonParser(ByteParser[Union[JsonDict, list]]):
+    """A parser that buffers the response and tries to parse it as JSON."""
+
+    CONTENT_TYPE = "application/json"
+
+    def __init__(self):
+        self._buffer = StringIO()
+        self._binary_wrapper = BinaryIOWrapper(self._buffer)
+
+    def write(self, data: bytes) -> int:
+        return self._binary_wrapper.write(data)
+
+    def finish(self) -> Union[JsonDict, list]:
+        return json_decoder.decode(self._buffer.getvalue())
+
+
+async def _handle_response(
     reactor: IReactorTime,
     timeout_sec: float,
     request: MatrixFederationRequest,
     response: IResponse,
     start_ms: int,
-) -> JsonDict:
+    parser: ByteParser[T],
+) -> T:
     """
-    Reads the JSON body of a response, with a timeout
+    Reads the body of a response with a timeout and sends it to a parser
 
     Args:
         reactor: twisted reactor, for the timeout
@@ -164,23 +215,21 @@ async def _handle_response(
         request: the request that triggered the response
         response: response to the request
         start_ms: Timestamp when request was made
+        parser: The parser for the response
 
     Returns:
-        The parsed JSON response
+        The parsed response
     """
-    try:
-        check_content_type_is_json(response.headers)
-
-        buf = StringIO()
-        d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE)
+    try:
+        check_content_type_is(response.headers, parser.CONTENT_TYPE)
+
+        d = read_body_with_max_size(response, parser, MAX_RESPONSE_SIZE)
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
-        def parse(_len: int):
-            return json_decoder.decode(buf.getvalue())
+        length = await make_deferred_yieldable(d)
 
-        d.addCallback(parse)
+        value = parser.finish()
 
-        body = await make_deferred_yieldable(d)
     except BodyExceededMaxSize as e:
         # The response was too big.
         logger.warning(
@@ -193,9 +242,9 @@ async def _handle_response(
         )
         raise RequestSendFailed(e, can_retry=False) from e
     except ValueError as e:
-        # The JSON content was invalid.
+        # The content was invalid.
         logger.warning(
-            "{%s} [%s] Failed to parse JSON response - %s %s",
+            "{%s} [%s] Failed to parse response - %s %s",
             request.txn_id,
             request.destination,
             request.method,
@@ -225,16 +274,17 @@ async def _handle_response(
     time_taken_secs = reactor.seconds() - start_ms / 1000
 
     logger.info(
-        "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
+        "{%s} [%s] Completed request: %d %s in %.2f secs, got %d bytes - %s %s",
         request.txn_id,
         request.destination,
        response.code,
        response.phrase.decode("ascii", errors="replace"),
        time_taken_secs,
+        length,
         request.method,
         request.uri.decode("ascii"),
     )
-    return body
+    return value
 
 
 class BinaryIOWrapper:
@@ -671,6 +721,7 @@ class MatrixFederationHttpClient:
         )
         return auth_headers
 
+    @overload
     async def put_json(
         self,
         destination: str,
@@ -683,7 +734,41 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         backoff_on_404: bool = False,
         try_trailing_slash_on_400: bool = False,
+        parser: Literal[None] = None,
     ) -> Union[JsonDict, list]:
+        ...
+
+    @overload
+    async def put_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = None,
+        data: Optional[JsonDict] = None,
+        json_data_callback: Optional[Callable[[], JsonDict]] = None,
+        long_retries: bool = False,
+        timeout: Optional[int] = None,
+        ignore_backoff: bool = False,
+        backoff_on_404: bool = False,
+        try_trailing_slash_on_400: bool = False,
+        parser: Optional[ByteParser[T]] = None,
+    ) -> T:
+        ...
+
+    async def put_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = None,
+        data: Optional[JsonDict] = None,
+        json_data_callback: Optional[Callable[[], JsonDict]] = None,
+        long_retries: bool = False,
+        timeout: Optional[int] = None,
+        ignore_backoff: bool = False,
+        backoff_on_404: bool = False,
+        try_trailing_slash_on_400: bool = False,
+        parser: Optional[ByteParser] = None,
+    ):
         """Sends the specified json data using PUT
 
         Args:
@@ -716,6 +801,8 @@ class MatrixFederationHttpClient:
                 of the request. Workaround for #3622 in Synapse <= v0.99.3. This
                 will be attempted before backing off if backing off has been
                 enabled.
+            parser: The parser to use to decode the response. Defaults to
+                parsing as JSON.
 
         Returns:
             Succeeds when we get a 2xx HTTP response. The
@@ -756,8 +843,16 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms
+        if parser is None:
+            parser = JsonParser()
+
+        body = await _handle_response(
+            self.reactor,
+            _sec_timeout,
+            request,
+            response,
+            start_ms,
+            parser=parser,
         )
 
         return body
@@ -830,12 +925,8 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor,
-            _sec_timeout,
-            request,
-            response,
-            start_ms,
+        body = await _handle_response(
+            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
         )
         return body
 
@@ -907,8 +998,8 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms
+        body = await _handle_response(
+            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
         )
 
         return body
@@ -975,8 +1066,8 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms
+        body = await _handle_response(
+            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
         )
         return body
|
@ -1068,16 +1159,16 @@ def _flatten_response_never_received(e):
|
||||||
return repr(e)
|
return repr(e)
|
||||||
|
|
||||||
|
|
||||||
def check_content_type_is_json(headers: Headers) -> None:
|
def check_content_type_is(headers: Headers, expected_content_type: str) -> None:
|
||||||
"""
|
"""
|
||||||
Check that a set of HTTP headers have a Content-Type header, and that it
|
Check that a set of HTTP headers have a Content-Type header, and that it
|
||||||
is application/json.
|
is the expected value..
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
headers: headers to check
|
headers: headers to check
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RequestSendFailed: if the Content-Type header is missing or isn't JSON
|
RequestSendFailed: if the Content-Type header is missing or doesn't match
|
||||||
|
|
||||||
"""
|
"""
|
||||||
content_type_headers = headers.getRawHeaders(b"Content-Type")
|
content_type_headers = headers.getRawHeaders(b"Content-Type")
|
||||||
|
@ -1089,11 +1180,10 @@ def check_content_type_is_json(headers: Headers) -> None:
|
||||||
|
|
||||||
c_type = content_type_headers[0].decode("ascii") # only the first header
|
c_type = content_type_headers[0].decode("ascii") # only the first header
|
||||||
val, options = cgi.parse_header(c_type)
|
val, options = cgi.parse_header(c_type)
|
||||||
if val != "application/json":
|
if val != expected_content_type:
|
||||||
raise RequestSendFailed(
|
raise RequestSendFailed(
|
||||||
RuntimeError(
|
RuntimeError(
|
||||||
"Remote server sent Content-Type header of '%s', not 'application/json'"
|
f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'",
|
||||||
% c_type,
|
|
||||||
),
|
),
|
||||||
can_retry=False,
|
can_retry=False,
|
||||||
)
|
)
|
||||||
|
|
|
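How the overloaded `put_json` is meant to be called, as a sketch only; the destination, path and the `RawBytesParser` from the earlier sketch are hypothetical and not part of this changeset:

# Sketch only.
result = await client.put_json(
    destination="remote.example.org",
    path="/_matrix/federation/v1/send/123",
    data={"pdus": []},
)  # decoded as JSON by default, via JsonParser

raw = await client.put_json(
    destination="remote.example.org",
    path="/_matrix/federation/v1/send/124",
    data={"pdus": []},
    parser=RawBytesParser(),  # hypothetical ByteParser subclass
)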
@ -56,14 +56,6 @@ class ModuleApi:
|
||||||
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
|
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
|
||||||
self._public_room_list_manager = PublicRoomListManager(hs)
|
self._public_room_list_manager = PublicRoomListManager(hs)
|
||||||
|
|
||||||
# The next time these users sync, they will receive the current presence
|
|
||||||
# state of all local users. Users are added by send_local_online_presence_to,
|
|
||||||
# and removed after a successful sync.
|
|
||||||
#
|
|
||||||
# We make this a private variable to deter modules from accessing it directly,
|
|
||||||
# though other classes in Synapse will still do so.
|
|
||||||
self._send_full_presence_to_local_users = set()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def http_client(self):
|
def http_client(self):
|
||||||
"""Allows making outbound HTTP requests to remote resources.
|
"""Allows making outbound HTTP requests to remote resources.
|
||||||
|
@ -405,38 +397,43 @@ class ModuleApi:
|
||||||
Updates to remote users will be sent immediately, whereas local users will receive
|
Updates to remote users will be sent immediately, whereas local users will receive
|
||||||
them on their next sync attempt.
|
them on their next sync attempt.
|
||||||
|
|
||||||
Note that this method can only be run on the main or federation_sender worker
|
Note that this method can only be run on the process that is configured to write to the
|
||||||
processes.
|
presence stream. By default this is the main process.
|
||||||
"""
|
"""
|
||||||
if not self._hs.should_send_federation():
|
if self._hs._instance_name not in self._hs.config.worker.writers.presence:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"send_local_online_presence_to can only be run "
|
"send_local_online_presence_to can only be run "
|
||||||
"on processes that send federation",
|
"on the process that is configured to write to the "
|
||||||
|
"presence stream (by default this is the main process)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
local_users = set()
|
||||||
|
remote_users = set()
|
||||||
for user in users:
|
for user in users:
|
||||||
if self._hs.is_mine_id(user):
|
if self._hs.is_mine_id(user):
|
||||||
# Modify SyncHandler._generate_sync_entry_for_presence to call
|
local_users.add(user)
|
||||||
# presence_source.get_new_events with an empty `from_key` if
|
|
||||||
# that user's ID were in a list modified by ModuleApi somewhere.
|
|
||||||
# That user would then get all presence state on next incremental sync.
|
|
||||||
|
|
||||||
# Force a presence initial_sync for this user next time
|
|
||||||
self._send_full_presence_to_local_users.add(user)
|
|
||||||
else:
|
else:
|
||||||
|
remote_users.add(user)
|
||||||
|
|
||||||
|
# We pull out the presence handler here to break a cyclic
|
||||||
|
# dependency between the presence router and module API.
|
||||||
|
presence_handler = self._hs.get_presence_handler()
|
||||||
|
|
||||||
|
if local_users:
|
||||||
|
# Force a presence initial_sync for these users next time they sync.
|
||||||
|
await presence_handler.send_full_presence_to_users(local_users)
|
||||||
|
|
||||||
|
for user in remote_users:
|
||||||
# Retrieve presence state for currently online users that this user
|
# Retrieve presence state for currently online users that this user
|
||||||
# is considered interested in
|
# is considered interested in.
|
||||||
presence_events, _ = await self._presence_stream.get_new_events(
|
presence_events, _ = await self._presence_stream.get_new_events(
|
||||||
UserID.from_string(user), from_key=None, include_offline=False
|
UserID.from_string(user), from_key=None, include_offline=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send to remote destinations.
|
# Send to remote destinations.
|
||||||
|
destination = UserID.from_string(user).domain
|
||||||
# We pull out the presence handler here to break a cyclic
|
presence_handler.get_federation_queue().send_presence_to_destinations(
|
||||||
# dependency between the presence router and module API.
|
presence_events, destination
|
||||||
presence_handler = self._hs.get_presence_handler()
|
|
||||||
await presence_handler.maybe_send_presence_to_interested_destinations(
|
|
||||||
presence_events
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
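As a usage sketch (not part of the commit): a module holding the ModuleApi instance would call the updated method as below. The function name and user IDs are placeholders; per the new docstring, this must run on the process configured to write to the presence stream.

async def on_module_startup(module_api):
    # Local users are flagged for a full presence snapshot on their next
    # /sync; presence for remote users is pushed over federation right away.
    await module_api.send_local_online_presence_to(
        [
            "@alice:localhost",         # placeholder local user
            "@bob:remote.example.com",  # placeholder remote user
        ]
    )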
@ -12,12 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import email.mime.multipart
|
|
||||||
import email.utils
|
|
||||||
import logging
|
import logging
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from email.mime.multipart import MIMEMultipart
|
|
||||||
from email.mime.text import MIMEText
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
|
||||||
|
|
||||||
import bleach
|
import bleach
|
||||||
|
@ -27,7 +23,6 @@ from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.config.emailconfig import EmailSubjectConfig
|
from synapse.config.emailconfig import EmailSubjectConfig
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
|
||||||
from synapse.push.presentable_names import (
|
from synapse.push.presentable_names import (
|
||||||
calculate_room_name,
|
calculate_room_name,
|
||||||
descriptor_from_member_events,
|
descriptor_from_member_events,
|
||||||
|
@ -108,7 +103,7 @@ class Mailer:
|
||||||
self.template_html = template_html
|
self.template_html = template_html
|
||||||
self.template_text = template_text
|
self.template_text = template_text
|
||||||
|
|
||||||
self.sendmail = self.hs.get_sendmail()
|
self.send_email_handler = hs.get_send_email_handler()
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.state_store = self.hs.get_storage().state
|
self.state_store = self.hs.get_storage().state
|
||||||
self.macaroon_gen = self.hs.get_macaroon_generator()
|
self.macaroon_gen = self.hs.get_macaroon_generator()
|
||||||
|
@ -310,17 +305,6 @@ class Mailer:
|
||||||
self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
|
self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send an email with the given information and template text"""
|
"""Send an email with the given information and template text"""
|
||||||
try:
|
|
||||||
from_string = self.hs.config.email_notif_from % {"app": self.app_name}
|
|
||||||
except TypeError:
|
|
||||||
from_string = self.hs.config.email_notif_from
|
|
||||||
|
|
||||||
raw_from = email.utils.parseaddr(from_string)[1]
|
|
||||||
raw_to = email.utils.parseaddr(email_address)[1]
|
|
||||||
|
|
||||||
if raw_to == "":
|
|
||||||
raise RuntimeError("Invalid 'to' address")
|
|
||||||
|
|
||||||
template_vars = {
|
template_vars = {
|
||||||
"app_name": self.app_name,
|
"app_name": self.app_name,
|
||||||
"server_name": self.hs.config.server.server_name,
|
"server_name": self.hs.config.server.server_name,
|
||||||
|
@ -329,35 +313,14 @@ class Mailer:
|
||||||
template_vars.update(extra_template_vars)
|
template_vars.update(extra_template_vars)
|
||||||
|
|
||||||
html_text = self.template_html.render(**template_vars)
|
html_text = self.template_html.render(**template_vars)
|
||||||
html_part = MIMEText(html_text, "html", "utf8")
|
|
||||||
|
|
||||||
plain_text = self.template_text.render(**template_vars)
|
plain_text = self.template_text.render(**template_vars)
|
||||||
text_part = MIMEText(plain_text, "plain", "utf8")
|
|
||||||
|
|
||||||
multipart_msg = MIMEMultipart("alternative")
|
await self.send_email_handler.send_email(
|
||||||
multipart_msg["Subject"] = subject
|
email_address=email_address,
|
||||||
multipart_msg["From"] = from_string
|
subject=subject,
|
||||||
multipart_msg["To"] = email_address
|
app_name=self.app_name,
|
||||||
multipart_msg["Date"] = email.utils.formatdate()
|
html=html_text,
|
||||||
multipart_msg["Message-ID"] = email.utils.make_msgid()
|
text=plain_text,
|
||||||
multipart_msg.attach(text_part)
|
|
||||||
multipart_msg.attach(html_part)
|
|
||||||
|
|
||||||
logger.info("Sending email to %s" % email_address)
|
|
||||||
|
|
||||||
await make_deferred_yieldable(
|
|
||||||
self.sendmail(
|
|
||||||
self.hs.config.email_smtp_host,
|
|
||||||
raw_from,
|
|
||||||
raw_to,
|
|
||||||
multipart_msg.as_string().encode("utf8"),
|
|
||||||
reactor=self.hs.get_reactor(),
|
|
||||||
port=self.hs.config.email_smtp_port,
|
|
||||||
requireAuthentication=self.hs.config.email_smtp_user is not None,
|
|
||||||
username=self.hs.config.email_smtp_user,
|
|
||||||
password=self.hs.config.email_smtp_pass,
|
|
||||||
requireTransportSecurity=self.hs.config.require_transport_security,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_room_vars(
|
async def _get_room_vars(
|
||||||
|
|
|
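From a caller's point of view, the change above reduces Mailer's job to rendering templates and handing them to the shared handler. A hedged sketch, using only the keyword arguments visible in the hunk; the recipient, subject and bodies are placeholders.

async def notify_by_email(hs, address):
    # MIME assembly and SMTP settings now live in the send-email handler
    # rather than being rebuilt inside Mailer.
    await hs.get_send_email_handler().send_email(
        email_address=address,                  # placeholder recipient
        subject="You have missed messages",     # placeholder subject
        app_name="Matrix",                      # placeholder app name
        html="<p>rendered HTML template</p>",   # placeholder bodies
        text="rendered plain-text template",
    )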
@ -87,6 +87,7 @@ REQUIREMENTS = [
|
||||||
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
||||||
# with the latest security patches.
|
# with the latest security patches.
|
||||||
"cryptography>=3.4.7",
|
"cryptography>=3.4.7",
|
||||||
|
"ijson>=3.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
CONDITIONAL_REQUIREMENTS = {
|
CONDITIONAL_REQUIREMENTS = {
|
||||||
|
|
|
@ -73,6 +73,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
|
||||||
{
|
{
|
||||||
"state": { ... },
|
"state": { ... },
|
||||||
"ignore_status_msg": false,
|
"ignore_status_msg": false,
|
||||||
|
"force_notify": false
|
||||||
}
|
}
|
||||||
|
|
||||||
200 OK
|
200 OK
|
||||||
|
@ -91,17 +92,23 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
|
||||||
self._presence_handler = hs.get_presence_handler()
|
self._presence_handler = hs.get_presence_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(user_id, state, ignore_status_msg=False):
|
async def _serialize_payload(
|
||||||
|
user_id, state, ignore_status_msg=False, force_notify=False
|
||||||
|
):
|
||||||
return {
|
return {
|
||||||
"state": state,
|
"state": state,
|
||||||
"ignore_status_msg": ignore_status_msg,
|
"ignore_status_msg": ignore_status_msg,
|
||||||
|
"force_notify": force_notify,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request(self, request, user_id):
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
await self._presence_handler.set_state(
|
await self._presence_handler.set_state(
|
||||||
UserID.from_string(user_id), content["state"], content["ignore_status_msg"]
|
UserID.from_string(user_id),
|
||||||
|
content["state"],
|
||||||
|
content["ignore_status_msg"],
|
||||||
|
content["force_notify"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|
|
@ -24,7 +24,7 @@ class SlavedClientIpStore(BaseSlavedStore):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.client_ip_last_seen = LruCache(
|
self.client_ip_last_seen = LruCache(
|
||||||
cache_name="client_ip_last_seen", keylen=4, max_size=50000
|
cache_name="client_ip_last_seen", max_size=50000
|
||||||
) # type: LruCache[tuple, int]
|
) # type: LruCache[tuple, int]
|
||||||
|
|
||||||
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
|
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
|
||||||
|
|
|
@ -1,21 +0,0 @@
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from synapse.storage.databases.main.transactions import TransactionStore
|
|
||||||
|
|
||||||
from ._base import BaseSlavedStore
|
|
||||||
|
|
||||||
|
|
||||||
class SlavedTransactionStore(TransactionStore, BaseSlavedStore):
|
|
||||||
pass
|
|
|
@ -54,7 +54,6 @@ class SendServerNoticeServlet(RestServlet):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.txns = HttpTransactionCache(hs)
|
self.txns = HttpTransactionCache(hs)
|
||||||
self.snm = hs.get_server_notices_manager()
|
|
||||||
|
|
||||||
def register(self, json_resource: HttpServer):
|
def register(self, json_resource: HttpServer):
|
||||||
PATTERN = "/send_server_notice"
|
PATTERN = "/send_server_notice"
|
||||||
|
@ -77,7 +76,10 @@ class SendServerNoticeServlet(RestServlet):
|
||||||
event_type = body.get("type", EventTypes.Message)
|
event_type = body.get("type", EventTypes.Message)
|
||||||
state_key = body.get("state_key")
|
state_key = body.get("state_key")
|
||||||
|
|
||||||
if not self.snm.is_enabled():
|
# We grab the server notices manager here as its initialisation has a check for worker processes,
|
||||||
|
# but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
|
||||||
|
# admin api).
|
||||||
|
if not self.hs.get_server_notices_manager().is_enabled():
|
||||||
raise SynapseError(400, "Server notices are not enabled on this server")
|
raise SynapseError(400, "Server notices are not enabled on this server")
|
||||||
|
|
||||||
user_id = body["user_id"]
|
user_id = body["user_id"]
|
||||||
|
@ -85,7 +87,7 @@ class SendServerNoticeServlet(RestServlet):
|
||||||
if not self.hs.is_mine_id(user_id):
|
if not self.hs.is_mine_id(user_id):
|
||||||
raise SynapseError(400, "Server notices can only be sent to local users")
|
raise SynapseError(400, "Server notices can only be sent to local users")
|
||||||
|
|
||||||
event = await self.snm.send_notice(
|
event = await self.hs.get_server_notices_manager().send_notice(
|
||||||
user_id=body["user_id"],
|
user_id=body["user_id"],
|
||||||
type=event_type,
|
type=event_type,
|
||||||
state_key=state_key,
|
state_key=state_key,
|
||||||
|
|
|
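For context, a hedged sketch of how this admin servlet is typically driven; the URL prefix, access token and message content are assumptions rather than something shown in this diff, and the body mirrors the fields the handler reads (user_id, plus optional type/state_key).

import requests

# Illustrative only: endpoint prefix and token are assumed, not taken from the diff.
resp = requests.post(
    "https://synapse.example.com/_synapse/admin/v1/send_server_notice",
    headers={"Authorization": "Bearer <admin access token>"},
    json={
        "user_id": "@alice:example.com",
        "content": {"msgtype": "m.text", "body": "Maintenance at 22:00 UTC"},
    },
)
resp.raise_for_status()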
@ -48,11 +48,6 @@ class LocalKey(Resource):
|
||||||
"key": # base64 encoded NACL verification key.
|
"key": # base64 encoded NACL verification key.
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"tls_fingerprints": [ # Fingerprints of the TLS certs this server uses.
|
|
||||||
{
|
|
||||||
"sha256": # base64 encoded sha256 fingerprint of the X509 cert
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"signatures": {
|
"signatures": {
|
||||||
"this.server.example.com": {
|
"this.server.example.com": {
|
||||||
"algorithm:version": # NACL signature for this server
|
"algorithm:version": # NACL signature for this server
|
||||||
|
@ -89,14 +84,11 @@ class LocalKey(Resource):
|
||||||
"expired_ts": key.expired_ts,
|
"expired_ts": key.expired_ts,
|
||||||
}
|
}
|
||||||
|
|
||||||
tls_fingerprints = self.config.tls_fingerprints
|
|
||||||
|
|
||||||
json_object = {
|
json_object = {
|
||||||
"valid_until_ts": self.valid_until_ts,
|
"valid_until_ts": self.valid_until_ts,
|
||||||
"server_name": self.config.server_name,
|
"server_name": self.config.server_name,
|
||||||
"verify_keys": verify_keys,
|
"verify_keys": verify_keys,
|
||||||
"old_verify_keys": old_verify_keys,
|
"old_verify_keys": old_verify_keys,
|
||||||
"tls_fingerprints": tls_fingerprints,
|
|
||||||
}
|
}
|
||||||
for key in self.config.signing_key:
|
for key in self.config.signing_key:
|
||||||
json_object = sign_json(json_object, self.config.server_name, key)
|
json_object = sign_json(json_object, self.config.server_name, key)
|
||||||
|
|
|
@ -73,9 +73,6 @@ class RemoteKey(DirectServeJsonResource):
|
||||||
"expired_ts": 0, # when the key stop being used.
|
"expired_ts": 0, # when the key stop being used.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"tls_fingerprints": [
|
|
||||||
{ "sha256": # fingerprint }
|
|
||||||
]
|
|
||||||
"signatures": {
|
"signatures": {
|
||||||
"remote.server.example.com": {...}
|
"remote.server.example.com": {...}
|
||||||
"this.server.example.com": {...}
|
"this.server.example.com": {...}
|
||||||
|
|
|
@ -76,6 +76,8 @@ class MediaRepository:
|
||||||
self.max_upload_size = hs.config.max_upload_size
|
self.max_upload_size = hs.config.max_upload_size
|
||||||
self.max_image_pixels = hs.config.max_image_pixels
|
self.max_image_pixels = hs.config.max_image_pixels
|
||||||
|
|
||||||
|
Thumbnailer.set_limits(self.max_image_pixels)
|
||||||
|
|
||||||
self.primary_base_path = hs.config.media_store_path # type: str
|
self.primary_base_path = hs.config.media_store_path # type: str
|
||||||
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
|
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,10 @@ class Thumbnailer:
|
||||||
|
|
||||||
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG", "image/webp": "WEBP"}
|
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG", "image/webp": "WEBP"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_limits(max_image_pixels: int):
|
||||||
|
Image.MAX_IMAGE_PIXELS = max_image_pixels
|
||||||
|
|
||||||
def __init__(self, input_path: str):
|
def __init__(self, input_path: str):
|
||||||
try:
|
try:
|
||||||
self.image = Image.open(input_path)
|
self.image = Image.open(input_path)
|
||||||
|
@ -47,6 +51,11 @@ class Thumbnailer:
|
||||||
# If an error occurs opening the image, a thumbnail won't be able to
|
# If an error occurs opening the image, a thumbnail won't be able to
|
||||||
# be generated.
|
# be generated.
|
||||||
raise ThumbnailError from e
|
raise ThumbnailError from e
|
||||||
|
except Image.DecompressionBombError as e:
|
||||||
|
# If an image decompression bomb error occurs opening the image,
|
||||||
|
# then the image exceeds the pixel limit and a thumbnail won't
|
||||||
|
# be able to be generated.
|
||||||
|
raise ThumbnailError from e
|
||||||
|
|
||||||
self.width, self.height = self.image.size
|
self.width, self.height = self.image.size
|
||||||
self.transpose_method = None
|
self.transpose_method = None
|
||||||
|
|
|
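The new branch leans on Pillow's built-in decompression-bomb protection: set_limits lowers Image.MAX_IMAGE_PIXELS, and opening an image that declares far more pixels than that raises Image.DecompressionBombError, which is then re-raised as a ThumbnailError. A small standalone illustration of the Pillow behaviour (the limit and file name are placeholders):

from PIL import Image

Image.MAX_IMAGE_PIXELS = 32 * 1024 * 1024  # placeholder cap, as set_limits would apply

try:
    image = Image.open("suspiciously-large.png")  # placeholder path
except Image.DecompressionBombError:
    # The image header declares vastly more pixels than the limit allows,
    # so refuse to generate a thumbnail instead of exhausting memory.
    print("image exceeds the pixel limit; no thumbnail generated")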
@ -104,6 +104,7 @@ from synapse.handlers.room_list import RoomListHandler
|
||||||
from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
|
from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
|
||||||
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
|
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
|
||||||
from synapse.handlers.search import SearchHandler
|
from synapse.handlers.search import SearchHandler
|
||||||
|
from synapse.handlers.send_email import SendEmailHandler
|
||||||
from synapse.handlers.set_password import SetPasswordHandler
|
from synapse.handlers.set_password import SetPasswordHandler
|
||||||
from synapse.handlers.space_summary import SpaceSummaryHandler
|
from synapse.handlers.space_summary import SpaceSummaryHandler
|
||||||
from synapse.handlers.sso import SsoHandler
|
from synapse.handlers.sso import SsoHandler
|
||||||
|
@ -549,6 +550,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
def get_search_handler(self) -> SearchHandler:
|
def get_search_handler(self) -> SearchHandler:
|
||||||
return SearchHandler(self)
|
return SearchHandler(self)
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
|
def get_send_email_handler(self) -> SendEmailHandler:
|
||||||
|
return SendEmailHandler(self)
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_set_password_handler(self) -> SetPasswordHandler:
|
def get_set_password_handler(self) -> SetPasswordHandler:
|
||||||
return SetPasswordHandler(self)
|
return SetPasswordHandler(self)
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import random
|
|
||||||
from abc import ABCMeta
|
from abc import ABCMeta
|
||||||
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
|
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
|
||||||
|
|
||||||
|
@ -44,7 +43,6 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.database_engine = database.engine
|
self.database_engine = database.engine
|
||||||
self.db_pool = database
|
self.db_pool = database
|
||||||
self.rand = random.SystemRandom()
|
|
||||||
|
|
||||||
def process_replication_rows(
|
def process_replication_rows(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -67,7 +67,7 @@ from .state import StateStore
|
||||||
from .stats import StatsStore
|
from .stats import StatsStore
|
||||||
from .stream import StreamStore
|
from .stream import StreamStore
|
||||||
from .tags import TagsStore
|
from .tags import TagsStore
|
||||||
from .transactions import TransactionStore
|
from .transactions import TransactionWorkerStore
|
||||||
from .ui_auth import UIAuthStore
|
from .ui_auth import UIAuthStore
|
||||||
from .user_directory import UserDirectoryStore
|
from .user_directory import UserDirectoryStore
|
||||||
from .user_erasure_store import UserErasureStore
|
from .user_erasure_store import UserErasureStore
|
||||||
|
@ -83,7 +83,7 @@ class DataStore(
|
||||||
StreamStore,
|
StreamStore,
|
||||||
ProfileStore,
|
ProfileStore,
|
||||||
PresenceStore,
|
PresenceStore,
|
||||||
TransactionStore,
|
TransactionWorkerStore,
|
||||||
DirectoryStore,
|
DirectoryStore,
|
||||||
KeyStore,
|
KeyStore,
|
||||||
StateStore,
|
StateStore,
|
||||||
|
|
|
@ -436,7 +436,7 @@ class ClientIpStore(ClientIpWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||||
|
|
||||||
self.client_ip_last_seen = LruCache(
|
self.client_ip_last_seen = LruCache(
|
||||||
cache_name="client_ip_last_seen", keylen=4, max_size=50000
|
cache_name="client_ip_last_seen", max_size=50000
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
|
@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
cached_method_name="get_device_list_last_stream_id_for_remote",
|
cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||||
list_name="user_ids",
|
list_name="user_ids",
|
||||||
)
|
)
|
||||||
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="device_lists_remote_extremeties",
|
table="device_lists_remote_extremeties",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
|
@ -1053,7 +1053,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
||||||
# the device exists.
|
# the device exists.
|
||||||
self.device_id_exists_cache = LruCache(
|
self.device_id_exists_cache = LruCache(
|
||||||
cache_name="device_id_exists", keylen=2, max_size=10000
|
cache_name="device_id_exists", max_size=10000
|
||||||
)
|
)
|
||||||
|
|
||||||
async def store_device(
|
async def store_device(
|
||||||
|
|
|
@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
||||||
num_args=1,
|
num_args=1,
|
||||||
)
|
)
|
||||||
async def _get_bare_e2e_cross_signing_keys_bulk(
|
async def _get_bare_e2e_cross_signing_keys_bulk(
|
||||||
self, user_ids: List[str]
|
self, user_ids: Iterable[str]
|
||||||
) -> Dict[str, Dict[str, dict]]:
|
) -> Dict[str, Dict[str, dict]]:
|
||||||
"""Returns the cross-signing keys for a set of users. The output of this
|
"""Returns the cross-signing keys for a set of users. The output of this
|
||||||
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||||
|
@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
||||||
def _get_bare_e2e_cross_signing_keys_bulk_txn(
|
def _get_bare_e2e_cross_signing_keys_bulk_txn(
|
||||||
self,
|
self,
|
||||||
txn: Connection,
|
txn: Connection,
|
||||||
user_ids: List[str],
|
user_ids: Iterable[str],
|
||||||
) -> Dict[str, Dict[str, dict]]:
|
) -> Dict[str, Dict[str, dict]]:
|
||||||
"""Returns the cross-signing keys for a set of users. The output of this
|
"""Returns the cross-signing keys for a set of users. The output of this
|
||||||
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||||
|
|
|
@ -157,7 +157,6 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
self._get_event_cache = LruCache(
|
self._get_event_cache = LruCache(
|
||||||
cache_name="*getEvent*",
|
cache_name="*getEvent*",
|
||||||
keylen=3,
|
|
||||||
max_size=hs.config.caches.event_cache_size,
|
max_size=hs.config.caches.event_cache_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ class KeyStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
keys = {}
|
keys = {}
|
||||||
|
|
||||||
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None:
|
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
|
||||||
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
|
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
|
||||||
|
|
||||||
# batch_iter always returns tuples so it's safe to do len(batch)
|
# batch_iter always returns tuples so it's safe to do len(batch)
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
|
||||||
|
|
||||||
from synapse.api.presence import PresenceState, UserPresenceState
|
from synapse.api.presence import PresenceState, UserPresenceState
|
||||||
from synapse.replication.tcp.streams import PresenceStream
|
from synapse.replication.tcp.streams import PresenceStream
|
||||||
|
@ -57,6 +57,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
db_conn, "presence_stream", "stream_id"
|
db_conn, "presence_stream", "stream_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.hs = hs
|
||||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||||
|
|
||||||
presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
|
presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
|
||||||
|
@ -96,6 +97,15 @@ class PresenceStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
|
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
|
||||||
|
|
||||||
|
# Delete old rows to stop database from getting really big
|
||||||
|
sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
|
||||||
|
|
||||||
|
for states in batch_iter(presence_states, 50):
|
||||||
|
clause, args = make_in_list_sql_clause(
|
||||||
|
self.database_engine, "user_id", [s.user_id for s in states]
|
||||||
|
)
|
||||||
|
txn.execute(sql + clause, [stream_id] + list(args))
|
||||||
|
|
||||||
# Actually insert new rows
|
# Actually insert new rows
|
||||||
self.db_pool.simple_insert_many_txn(
|
self.db_pool.simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -116,15 +126,6 @@ class PresenceStore(SQLBaseStore):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete old rows to stop database from getting really big
|
|
||||||
sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
|
|
||||||
|
|
||||||
for states in batch_iter(presence_states, 50):
|
|
||||||
clause, args = make_in_list_sql_clause(
|
|
||||||
self.database_engine, "user_id", [s.user_id for s in states]
|
|
||||||
)
|
|
||||||
txn.execute(sql + clause, [stream_id] + list(args))
|
|
||||||
|
|
||||||
async def get_all_presence_updates(
|
async def get_all_presence_updates(
|
||||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||||
) -> Tuple[List[Tuple[int, list]], int, bool]:
|
) -> Tuple[List[Tuple[int, list]], int, bool]:
|
||||||
|
@ -210,6 +211,61 @@ class PresenceStore(SQLBaseStore):
|
||||||
|
|
||||||
return {row["user_id"]: UserPresenceState(**row) for row in rows}
|
return {row["user_id"]: UserPresenceState(**row) for row in rows}
|
||||||
|
|
||||||
|
async def should_user_receive_full_presence_with_token(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
from_token: int,
|
||||||
|
) -> bool:
|
||||||
|
"""Check whether the given user should receive full presence using the stream token
|
||||||
|
they're updating from.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The ID of the user to check.
|
||||||
|
from_token: The stream token included in their /sync token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the user should have full presence sent to them, False otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _should_user_receive_full_presence_with_token_txn(txn):
|
||||||
|
sql = """
|
||||||
|
SELECT 1 FROM users_to_send_full_presence_to
|
||||||
|
WHERE user_id = ?
|
||||||
|
AND presence_stream_id >= ?
|
||||||
|
"""
|
||||||
|
txn.execute(sql, (user_id, from_token))
|
||||||
|
return bool(txn.fetchone())
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"should_user_receive_full_presence_with_token",
|
||||||
|
_should_user_receive_full_presence_with_token_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
|
||||||
|
"""Adds to the list of users who should receive a full snapshot of presence
|
||||||
|
upon their next sync.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_ids: An iterable of user IDs.
|
||||||
|
"""
|
||||||
|
# Add user entries to the table, updating the presence_stream_id column if the user already
|
||||||
|
# exists in the table.
|
||||||
|
await self.db_pool.simple_upsert_many(
|
||||||
|
table="users_to_send_full_presence_to",
|
||||||
|
key_names=("user_id",),
|
||||||
|
key_values=[(user_id,) for user_id in user_ids],
|
||||||
|
value_names=("presence_stream_id",),
|
||||||
|
# We save the current presence stream ID token along with the user ID entry so
|
||||||
|
# that when a user /sync's, even if they syncing multiple times across separate
|
||||||
|
# devices at different times, each device will receive full presence once - when
|
||||||
|
# the presence stream ID in their sync token is less than the one in the table
|
||||||
|
# for their user ID.
|
||||||
|
value_values=(
|
||||||
|
(self._presence_id_gen.get_current_token(),) for _ in user_ids
|
||||||
|
),
|
||||||
|
desc="add_users_to_send_full_presence_to",
|
||||||
|
)
|
||||||
|
|
||||||
async def get_presence_for_all_users(
|
async def get_presence_for_all_users(
|
||||||
self,
|
self,
|
||||||
include_offline: bool = True,
|
include_offline: bool = True,
|
||||||
|
|
|
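Taken together, the two new storage methods sketch the intended flow: something (for example the module API change earlier in this commit) flags users, and the sync path later compares the stored presence stream ID against the client's token. A hedged illustration, assuming `store` is the PresenceStore and `sync_token_presence_id` is the presence position from the client's sync token:

async def needs_full_presence(store, user_id, sync_token_presence_id):
    # Step 1 (e.g. triggered via the module API): flag the user.
    await store.add_users_to_send_full_presence_to([user_id])

    # Step 2 (on the user's next /sync): if the stored stream ID is at or
    # beyond the token they are syncing from, do a full presence sync.
    return await store.should_user_receive_full_presence_with_token(
        user_id, from_token=sync_token_presence_id
    )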
@ -14,6 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -997,7 +998,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
expiration_ts = now_ms + self._account_validity_period
|
expiration_ts = now_ms + self._account_validity_period
|
||||||
|
|
||||||
if use_delta:
|
if use_delta:
|
||||||
expiration_ts = self.rand.randrange(
|
expiration_ts = random.randrange(
|
||||||
expiration_ts - self._account_validity_startup_job_max_delta,
|
expiration_ts - self._account_validity_startup_job_max_delta,
|
||||||
expiration_ts,
|
expiration_ts,
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,13 +16,15 @@ import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import attr
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import db_to_json
|
||||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||||
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
db_binary_type = memoryview
|
db_binary_type = memoryview
|
||||||
|
|
||||||
|
@ -38,10 +40,23 @@ _UpdateTransactionRow = namedtuple(
|
||||||
"_TransactionRow", ("response_code", "response_json")
|
"_TransactionRow", ("response_code", "response_json")
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL = object()
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class DestinationRetryTimings:
|
||||||
|
"""The current destination retry timing info for a remote server."""
|
||||||
|
|
||||||
|
# The first time we tried and failed to reach the remote server, in ms.
|
||||||
|
failure_ts: int
|
||||||
|
|
||||||
|
# The last time we tried and failed to reach the remote server, in ms.
|
||||||
|
retry_last_ts: int
|
||||||
|
|
||||||
|
# How long since the last time we tried to reach the remote server before
|
||||||
|
# trying again, in ms.
|
||||||
|
retry_interval: int
|
||||||
|
|
||||||
|
|
||||||
class TransactionWorkerStore(SQLBaseStore):
|
class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
@ -60,19 +75,6 @@ class TransactionWorkerStore(SQLBaseStore):
|
||||||
"_cleanup_transactions", _cleanup_transactions_txn
|
"_cleanup_transactions", _cleanup_transactions_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TransactionStore(TransactionWorkerStore):
|
|
||||||
"""A collection of queries for handling PDUs."""
|
|
||||||
|
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
|
||||||
super().__init__(database, db_conn, hs)
|
|
||||||
|
|
||||||
self._destination_retry_cache = ExpiringCache(
|
|
||||||
cache_name="get_destination_retry_timings",
|
|
||||||
clock=self._clock,
|
|
||||||
expiry_ms=5 * 60 * 1000,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_received_txn_response(
|
async def get_received_txn_response(
|
||||||
self, transaction_id: str, origin: str
|
self, transaction_id: str, origin: str
|
||||||
) -> Optional[Tuple[int, JsonDict]]:
|
) -> Optional[Tuple[int, JsonDict]]:
|
||||||
|
@ -145,7 +147,11 @@ class TransactionStore(TransactionWorkerStore):
|
||||||
desc="set_received_txn_response",
|
desc="set_received_txn_response",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_destination_retry_timings(self, destination):
|
@cached(max_entries=10000)
|
||||||
|
async def get_destination_retry_timings(
|
||||||
|
self,
|
||||||
|
destination: str,
|
||||||
|
) -> Optional[DestinationRetryTimings]:
|
||||||
"""Gets the current retry timings (if any) for a given destination.
|
"""Gets the current retry timings (if any) for a given destination.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -156,34 +162,29 @@ class TransactionStore(TransactionWorkerStore):
|
||||||
Otherwise a dict for the retry scheme
|
Otherwise a dict for the retry scheme
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = self._destination_retry_cache.get(destination, SENTINEL)
|
|
||||||
if result is not SENTINEL:
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await self.db_pool.runInteraction(
|
result = await self.db_pool.runInteraction(
|
||||||
"get_destination_retry_timings",
|
"get_destination_retry_timings",
|
||||||
self._get_destination_retry_timings,
|
self._get_destination_retry_timings,
|
||||||
destination,
|
destination,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We don't hugely care about race conditions between getting and
|
|
||||||
# invalidating the cache, since we time out fairly quickly anyway.
|
|
||||||
self._destination_retry_cache[destination] = result
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_destination_retry_timings(self, txn, destination):
|
def _get_destination_retry_timings(
|
||||||
|
self, txn, destination: str
|
||||||
|
) -> Optional[DestinationRetryTimings]:
|
||||||
result = self.db_pool.simple_select_one_txn(
|
result = self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
table="destinations",
|
table="destinations",
|
||||||
keyvalues={"destination": destination},
|
keyvalues={"destination": destination},
|
||||||
retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
|
retcols=("failure_ts", "retry_last_ts", "retry_interval"),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# check we have a row and retry_last_ts is not null or zero
|
# check we have a row and retry_last_ts is not null or zero
|
||||||
# (retry_last_ts can't be negative)
|
# (retry_last_ts can't be negative)
|
||||||
if result and result["retry_last_ts"]:
|
if result and result["retry_last_ts"]:
|
||||||
return result
|
return DestinationRetryTimings(**result)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -204,7 +205,6 @@ class TransactionStore(TransactionWorkerStore):
|
||||||
retry_interval: how long until next retry in ms
|
retry_interval: how long until next retry in ms
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._destination_retry_cache.pop(destination, None)
|
|
||||||
if self.database_engine.can_native_upsert:
|
if self.database_engine.can_native_upsert:
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"set_destination_retry_timings",
|
"set_destination_retry_timings",
|
||||||
|
@ -252,6 +252,10 @@ class TransactionStore(TransactionWorkerStore):
|
||||||
|
|
||||||
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
|
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
|
||||||
|
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_destination_retry_timings, (destination,)
|
||||||
|
)
|
||||||
|
|
||||||
def _set_destination_retry_timings_emulated(
|
def _set_destination_retry_timings_emulated(
|
||||||
self, txn, destination, failure_ts, retry_last_ts, retry_interval
|
self, txn, destination, failure_ts, retry_last_ts, retry_interval
|
||||||
):
|
):
|
||||||
|
@ -295,6 +299,10 @@ class TransactionStore(TransactionWorkerStore):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_destination_retry_timings, (destination,)
|
||||||
|
)
|
||||||
|
|
||||||
async def store_destination_rooms_entries(
|
async def store_destination_rooms_entries(
|
||||||
self,
|
self,
|
||||||
destinations: Iterable[str],
|
destinations: Iterable[str],
|
||||||
|
|
|
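Callers of get_destination_retry_timings now get back a DestinationRetryTimings instance (or None) from the shared @cached cache, instead of a raw row dict from a per-process ExpiringCache. A hedged sketch of a caller, with `store` standing in for the datastore:

import time

async def is_destination_backing_off(store, destination):
    # True if we recently failed to reach `destination` and are still
    # inside its retry interval.
    timings = await store.get_destination_retry_timings(destination)
    if timings is None:
        return False  # no recorded failures for this destination
    now_ms = int(time.time() * 1000)
    return now_ms < timings.retry_last_ts + timings.retry_interval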
@ -12,6 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Dict, Iterable
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
|
@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||||
return bool(result)
|
return bool(result)
|
||||||
|
|
||||||
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
|
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
|
||||||
async def are_users_erased(self, user_ids):
|
async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
|
||||||
"""
|
"""
|
||||||
Checks which users in a list have requested erasure
|
Checks which users in a list have requested erasure
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_ids (iterable[str]): full user id to check
|
user_ids: full user ids to check
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, bool]:
|
|
||||||
for each user, whether the user has requested erasure.
|
for each user, whether the user has requested erasure.
|
||||||
"""
|
"""
|
||||||
# this serves the dual purpose of (a) making sure we can do len and
|
|
||||||
# iterate it multiple times, and (b) avoiding duplicates.
|
|
||||||
user_ids = tuple(set(user_ids))
|
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="erased_users",
|
table="erased_users",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- Add a table that keeps track of a list of users who should, upon their next
|
||||||
|
-- sync request, receive presence for all currently online users that they are
|
||||||
|
-- "interested" in.
|
||||||
|
|
||||||
|
-- The motivation for a DB table over an in-memory list is so that this list
|
||||||
|
-- can be added to and retrieved from by any worker. Specifically, we don't
|
||||||
|
-- want to duplicate work across multiple sync workers.
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS users_to_send_full_presence_to(
|
||||||
|
-- The user ID to send full presence to.
|
||||||
|
user_id TEXT PRIMARY KEY,
|
||||||
|
-- A presence stream ID token - the current presence stream token when the row was last upserted.
|
||||||
|
-- If a user calls /sync and this token is part of the update they're to receive, we also include
|
||||||
|
-- full user presence in the response.
|
||||||
|
-- This allows multiple devices for a user to receive full presence whenever they next call /sync.
|
||||||
|
presence_stream_id BIGINT,
|
||||||
|
FOREIGN KEY (user_id)
|
||||||
|
REFERENCES users (name)
|
||||||
|
);
|
|
@ -540,7 +540,7 @@ class StateGroupStorage:
|
||||||
state_filter: The state filter used to fetch state from the database.
|
state_filter: The state filter used to fetch state from the database.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dict from (type, state_key) -> state_event
|
A dict from (type, state_key) -> state_event_id
|
||||||
"""
|
"""
|
||||||
state_map = await self.get_state_ids_for_events(
|
state_map = await self.get_state_ids_for_events(
|
||||||
[event_id], state_filter or StateFilter.all()
|
[event_id], state_filter or StateFilter.all()
|
||||||
|
|
153
synapse/util/batching_queue.py
Normal file
|
@ -0,0 +1,153 @@
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import (
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
Hashable,
|
||||||
|
List,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||||
|
from synapse.metrics import LaterGauge
|
||||||
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
V = TypeVar("V")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchingQueue(Generic[V, R]):
|
||||||
|
"""A queue that batches up work, calling the provided processing function
|
||||||
|
with all pending work (for a given key).
|
||||||
|
|
||||||
|
The provided processing function will only be called once at a time for each
|
||||||
|
key. It will be called the next reactor tick after `add_to_queue` has been
|
||||||
|
called, and will keep being called until the queue has been drained (for the
|
||||||
|
given key).
|
||||||
|
|
||||||
|
Note that the return value of `add_to_queue` will be the return value of the
|
||||||
|
processing function that processed the given item. This means that the
|
||||||
|
returned value will likely include data for other items that were in the
|
||||||
|
batch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
clock: Clock,
|
||||||
|
process_batch_callback: Callable[[List[V]], Awaitable[R]],
|
||||||
|
):
|
||||||
|
self._name = name
|
||||||
|
self._clock = clock
|
||||||
|
|
||||||
|
# The set of keys currently being processed.
|
||||||
|
self._processing_keys = set() # type: Set[Hashable]
|
||||||
|
|
||||||
|
# The currently pending batch of values by key, with a Deferred to call
|
||||||
|
# with the result of the corresponding `_process_batch_callback` call.
|
||||||
|
self._next_values = {} # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]]
|
||||||
|
|
||||||
|
# The function to call with batches of values.
|
||||||
|
self._process_batch_callback = process_batch_callback
|
||||||
|
|
||||||
|
LaterGauge(
|
||||||
|
"synapse_util_batching_queue_number_queued",
|
||||||
|
"The number of items waiting in the queue across all keys",
|
||||||
|
labels=("name",),
|
||||||
|
caller=lambda: sum(len(v) for v in self._next_values.values()),
|
||||||
|
)
|
||||||
|
|
||||||
|
LaterGauge(
|
||||||
|
"synapse_util_batching_queue_number_of_keys",
|
||||||
|
"The number of distinct keys that have items queued",
|
||||||
|
labels=("name",),
|
||||||
|
caller=lambda: len(self._next_values),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
|
||||||
|
"""Adds the value to the queue with the given key, returning the result
|
||||||
|
of the processing function for the batch that included the given value.
|
||||||
|
|
||||||
|
The optional `key` argument allows sharding the queue by some key. The
|
||||||
|
queues will then be processed in parallel, i.e. the process batch
|
||||||
|
function will be called in parallel with batched values from a single
|
||||||
|
key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# First we create a defer and add it and the value to the list of
|
||||||
|
# pending items.
|
||||||
|
d = defer.Deferred()
|
||||||
|
self._next_values.setdefault(key, []).append((value, d))
|
||||||
|
|
||||||
|
# If we're not currently processing the key fire off a background
|
||||||
|
# process to start processing.
|
||||||
|
if key not in self._processing_keys:
|
||||||
|
run_as_background_process(self._name, self._process_queue, key)
|
||||||
|
|
||||||
|
return await make_deferred_yieldable(d)
|
||||||
|
|
||||||
|
async def _process_queue(self, key: Hashable) -> None:
|
||||||
|
"""A background task to repeatedly pull things off the queue for the
|
||||||
|
given key and call the `self._process_batch_callback` with the values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
if key in self._processing_keys:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._processing_keys.add(key)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# We purposefully wait a reactor tick to allow us to batch
|
||||||
|
# together requests that we're about to receive. A common
|
||||||
|
# pattern is to call `add_to_queue` multiple times at once, and
|
||||||
|
# deferring to the next reactor tick allows us to batch all of
|
||||||
|
# those up.
|
||||||
|
await self._clock.sleep(0)
|
||||||
|
|
||||||
|
next_values = self._next_values.pop(key, [])
|
||||||
|
if not next_values:
|
||||||
|
# We've exhausted the queue.
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
values = [value for value, _ in next_values]
|
||||||
|
results = await self._process_batch_callback(values)
|
||||||
|
|
||||||
|
for _, deferred in next_values:
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
deferred.callback(results)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
for _, deferred in next_values:
|
||||||
|
if deferred.called:
|
||||||
|
continue
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
deferred.errback(e)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._processing_keys.discard(key)
|
|
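A usage sketch for the new queue (not part of the commit), assuming the homeserver Clock is available; the callback simply sums a batch, standing in for real work such as bulk persistence.

from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable
from synapse.util.batching_queue import BatchingQueue


async def example(clock):
    async def process_batch(values):
        # Called with everything queued for the key since the last call.
        return sum(values)

    queue = BatchingQueue("example_queue", clock, process_batch)

    # Values queued within the same reactor tick land in one batch, and each
    # caller receives the batch's result rather than just its own item's.
    d1 = defer.ensureDeferred(queue.add_to_queue(1, key="room1"))
    d2 = defer.ensureDeferred(queue.add_to_queue(2, key="room1"))
    results = await make_deferred_yieldable(defer.gatherResults([d1, d2]))
    # results is likely [3, 3].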
@ -70,7 +70,6 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
max_entries: int = 1000,
|
max_entries: int = 1000,
|
||||||
keylen: int = 1,
|
|
||||||
tree: bool = False,
|
tree: bool = False,
|
||||||
iterable: bool = False,
|
iterable: bool = False,
|
||||||
apply_cache_factor_from_config: bool = True,
|
apply_cache_factor_from_config: bool = True,
|
||||||
|
@ -101,7 +100,6 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
# a Deferred.
|
# a Deferred.
|
||||||
self.cache = LruCache(
|
self.cache = LruCache(
|
||||||
max_size=max_entries,
|
max_size=max_entries,
|
||||||
keylen=keylen,
|
|
||||||
cache_name=name,
|
cache_name=name,
|
||||||
cache_type=cache_type,
|
cache_type=cache_type,
|
||||||
size_callback=(lambda d: len(d) or 1) if iterable else None,
|
size_callback=(lambda d: len(d) or 1) if iterable else None,
|
||||||
|
|
|
@ -270,7 +270,6 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
cache = DeferredCache(
|
cache = DeferredCache(
|
||||||
name=self.orig.__name__,
|
name=self.orig.__name__,
|
||||||
max_entries=self.max_entries,
|
max_entries=self.max_entries,
|
||||||
keylen=self.num_args,
|
|
||||||
tree=self.tree,
|
tree=self.tree,
|
||||||
iterable=self.iterable,
|
iterable=self.iterable,
|
||||||
) # type: DeferredCache[CacheKey, Any]
|
) # type: DeferredCache[CacheKey, Any]
|
||||||
|
@ -322,8 +321,8 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
"""Wraps an existing cache to support bulk fetching of keys.
|
"""Wraps an existing cache to support bulk fetching of keys.
|
||||||
|
|
||||||
Given a list of keys it looks in the cache to find any hits, then passes
|
Given an iterable of keys it looks in the cache to find any hits, then passes
|
||||||
the list of missing keys to the wrapped function.
|
the tuple of missing keys to the wrapped function.
|
||||||
|
|
||||||
Once wrapped, the function returns a Deferred which resolves to the list
|
Once wrapped, the function returns a Deferred which resolves to the list
|
||||||
of results.
|
of results.
|
||||||
|
@ -437,7 +436,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
return f
|
return f
|
||||||
|
|
||||||
args_to_call = dict(arg_dict)
|
args_to_call = dict(arg_dict)
|
||||||
args_to_call[self.list_name] = list(missing)
|
# copy the missing set before sending it to the callee, to guard against
|
||||||
|
# modification.
|
||||||
|
args_to_call[self.list_name] = tuple(missing)
|
||||||
|
|
||||||
cached_defers.append(
|
cached_defers.append(
|
||||||
defer.maybeDeferred(
|
defer.maybeDeferred(
|
||||||
|
@ -522,14 +523,14 @@ def cachedList(
|
||||||
|
|
||||||
Used to do batch lookups for an already created cache. A single argument
|
Used to do batch lookups for an already created cache. A single argument
|
||||||
is specified as a list that is iterated through to lookup keys in the
|
is specified as a list that is iterated through to lookup keys in the
|
||||||
original cache. A new list consisting of the keys that weren't in the cache
|
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
|
||||||
get passed to the original function, the result of which is stored in the
|
the cache gets passed to the original function, the result of which is stored in the
|
||||||
cache.
|
cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cached_method_name: The name of the single-item lookup method.
|
cached_method_name: The name of the single-item lookup method.
|
||||||
This is only used to find the cache to use.
|
This is only used to find the cache to use.
|
||||||
list_name: The name of the argument that is the list to use to
|
list_name: The name of the argument that is the iterable to use to
|
||||||
do batch lookups in the cache.
|
do batch lookups in the cache.
|
||||||
num_args: Number of arguments to use as the key in the cache
|
num_args: Number of arguments to use as the key in the cache
|
||||||
(including list_name). Defaults to all named parameters.
|
(including list_name). Defaults to all named parameters.
|
||||||
|
|
|
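For reference, a hedged sketch of the pattern this docstring describes, loosely modelled on the is_user_erased / are_users_erased pair elsewhere in this commit; the store class and its lookup are placeholders. The bulk method now receives a tuple of the deduplicated cache misses.

from synapse.util.caches.descriptors import cached, cachedList


class ExampleStore:
    @cached()
    async def is_user_erased(self, user_id: str) -> bool:
        ...  # single-item lookup backed by the database

    @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
    async def are_users_erased(self, user_ids) -> dict:
        # `user_ids` is the tuple of keys that were not already cached;
        # return a dict mapping each of them to its value.
        return {user_id: False for user_id in user_ids}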
@ -34,7 +34,7 @@ from typing_extensions import Literal
|
||||||
from synapse.config import cache as cache_config
|
from synapse.config import cache as cache_config
|
||||||
from synapse.util import caches
|
from synapse.util import caches
|
||||||
from synapse.util.caches import CacheMetric, register_cache
|
from synapse.util.caches import CacheMetric, register_cache
|
||||||
from synapse.util.caches.treecache import TreeCache
|
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from pympler.asizeof import Asizer
|
from pympler.asizeof import Asizer
|
||||||
|
@ -160,7 +160,6 @@ class LruCache(Generic[KT, VT]):
|
||||||
self,
|
self,
|
||||||
max_size: int,
|
max_size: int,
|
||||||
cache_name: Optional[str] = None,
|
cache_name: Optional[str] = None,
|
||||||
keylen: int = 1,
|
|
||||||
cache_type: Type[Union[dict, TreeCache]] = dict,
|
cache_type: Type[Union[dict, TreeCache]] = dict,
|
||||||
size_callback: Optional[Callable] = None,
|
size_callback: Optional[Callable] = None,
|
||||||
metrics_collection_callback: Optional[Callable[[], None]] = None,
|
metrics_collection_callback: Optional[Callable[[], None]] = None,
|
||||||
|
@ -173,9 +172,6 @@ class LruCache(Generic[KT, VT]):
|
||||||
cache_name: The name of this cache, for the prometheus metrics. If unset,
|
cache_name: The name of this cache, for the prometheus metrics. If unset,
|
||||||
no metrics will be reported on this cache.
|
no metrics will be reported on this cache.
|
||||||
|
|
||||||
keylen: The length of the tuple used as the cache key. Ignored unless
|
|
||||||
cache_type is `TreeCache`.
|
|
||||||
|
|
||||||
cache_type (type):
|
cache_type (type):
|
||||||
type of underlying cache to be used. Typically one of dict
|
type of underlying cache to be used. Typically one of dict
|
||||||
or TreeCache.
|
or TreeCache.
|
||||||
|
@ -403,7 +399,9 @@ class LruCache(Generic[KT, VT]):
|
||||||
popped = cache.pop(key)
|
popped = cache.pop(key)
|
||||||
if popped is None:
|
if popped is None:
|
||||||
return
|
return
|
||||||
for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
|
# for each deleted node, we now need to remove it from the linked list
|
||||||
|
# and run its callbacks.
|
||||||
|
for leaf in iterate_tree_cache_entry(popped):
|
||||||
delete_node(leaf)
|
delete_node(leaf)
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
|
|
|
@ -1,18 +1,43 @@
|
||||||
from typing import Dict
|
# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
SENTINEL = object()
|
SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
|
class TreeCacheNode(dict):
|
||||||
|
"""The type of nodes in our tree.
|
||||||
|
|
||||||
|
Has its own type so we can distinguish it from real dicts that are stored at the
|
||||||
|
leaves.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TreeCache:
|
class TreeCache:
|
||||||
"""
|
"""
|
||||||
Tree-based backing store for LruCache. Allows subtrees of data to be deleted
|
Tree-based backing store for LruCache. Allows subtrees of data to be deleted
|
||||||
efficiently.
|
efficiently.
|
||||||
Keys must be tuples.
|
Keys must be tuples.
|
||||||
|
|
||||||
|
The data structure is a chain of TreeCacheNodes:
|
||||||
|
root = {key_1: {key_2: _value}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.size = 0
|
self.size = 0
|
||||||
self.root = {} # type: Dict
|
self.root = TreeCacheNode()
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
return self.set(key, value)
|
return self.set(key, value)
|
||||||
|
@ -21,10 +46,23 @@ class TreeCache:
|
||||||
return self.get(key, SENTINEL) is not SENTINEL
|
return self.get(key, SENTINEL) is not SENTINEL
|
||||||
|
|
||||||
def set(self, key, value):
|
def set(self, key, value):
|
||||||
|
if isinstance(value, TreeCacheNode):
|
||||||
|
# this would mean we couldn't tell where our tree ended and the value
|
||||||
|
# started.
|
||||||
|
raise ValueError("Cannot store TreeCacheNodes in a TreeCache")
|
||||||
|
|
||||||
node = self.root
|
node = self.root
|
||||||
for k in key[:-1]:
|
for k in key[:-1]:
|
||||||
node = node.setdefault(k, {})
|
next_node = node.get(k, SENTINEL)
|
||||||
node[key[-1]] = _Entry(value)
|
if next_node is SENTINEL:
|
||||||
|
next_node = node[k] = TreeCacheNode()
|
||||||
|
elif not isinstance(next_node, TreeCacheNode):
|
||||||
|
# this suggests that the caller is not being consistent with its key
|
||||||
|
# length.
|
||||||
|
raise ValueError("value conflicts with an existing subtree")
|
||||||
|
node = next_node
|
||||||
|
|
||||||
|
node[key[-1]] = value
|
||||||
self.size += 1
|
self.size += 1
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
|
@ -33,25 +71,41 @@ class TreeCache:
|
||||||
node = node.get(k, None)
|
node = node.get(k, None)
|
||||||
if node is None:
|
if node is None:
|
||||||
return default
|
return default
|
||||||
return node.get(key[-1], _Entry(default)).value
|
return node.get(key[-1], default)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.size = 0
|
self.size = 0
|
||||||
self.root = {}
|
self.root = TreeCacheNode()
|
||||||
|
|
||||||
def pop(self, key, default=None):
|
def pop(self, key, default=None):
|
||||||
|
"""Remove the given key, or subkey, from the cache
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: key or subkey to remove.
|
||||||
|
default: value to return if key is not found
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the key is not found, 'default'. If the key is complete, the removed
|
||||||
|
value. If the key is partial, the TreeCacheNode corresponding to the part
|
||||||
|
of the tree that was removed.
|
||||||
|
"""
|
||||||
|
# a list of the nodes we have touched on the way down the tree
|
||||||
nodes = []
|
nodes = []
|
||||||
|
|
||||||
node = self.root
|
node = self.root
|
||||||
for k in key[:-1]:
|
for k in key[:-1]:
|
||||||
node = node.get(k, None)
|
node = node.get(k, None)
|
||||||
nodes.append(node) # don't add the root node
|
|
||||||
if node is None:
|
if node is None:
|
||||||
return default
|
return default
|
||||||
|
if not isinstance(node, TreeCacheNode):
|
||||||
|
# we've gone off the end of the tree
|
||||||
|
raise ValueError("pop() key too long")
|
||||||
|
nodes.append(node) # don't add the root node
|
||||||
popped = node.pop(key[-1], SENTINEL)
|
popped = node.pop(key[-1], SENTINEL)
|
||||||
if popped is SENTINEL:
|
if popped is SENTINEL:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
# working back up the tree, clear out any nodes that are now empty
|
||||||
node_and_keys = list(zip(nodes, key))
|
node_and_keys = list(zip(nodes, key))
|
||||||
node_and_keys.reverse()
|
node_and_keys.reverse()
|
||||||
node_and_keys.append((self.root, None))
|
node_and_keys.append((self.root, None))
|
||||||
|
@ -61,14 +115,15 @@ class TreeCache:
|
||||||
|
|
||||||
if n:
|
if n:
|
||||||
break
|
break
|
||||||
|
# found an empty node: remove it from its parent, and loop.
|
||||||
node_and_keys[i + 1][0].pop(k)
|
node_and_keys[i + 1][0].pop(k)
|
||||||
|
|
||||||
popped, cnt = _strip_and_count_entires(popped)
|
cnt = sum(1 for _ in iterate_tree_cache_entry(popped))
|
||||||
self.size -= cnt
|
self.size -= cnt
|
||||||
return popped
|
return popped
|
||||||
|
|
||||||
def values(self):
|
def values(self):
|
||||||
return list(iterate_tree_cache_entry(self.root))
|
return iterate_tree_cache_entry(self.root)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.size
|
return self.size
|
||||||
|
@ -78,36 +133,9 @@ def iterate_tree_cache_entry(d):
|
||||||
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
|
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
|
||||||
can contain dicts.
|
can contain dicts.
|
||||||
"""
|
"""
|
||||||
if isinstance(d, dict):
|
if isinstance(d, TreeCacheNode):
|
||||||
for value_d in d.values():
|
for value_d in d.values():
|
||||||
for value in iterate_tree_cache_entry(value_d):
|
for value in iterate_tree_cache_entry(value_d):
|
||||||
yield value
|
yield value
|
||||||
else:
|
|
||||||
if isinstance(d, _Entry):
|
|
||||||
yield d.value
|
|
||||||
else:
|
else:
|
||||||
yield d
|
yield d
|
||||||
|
|
||||||
|
|
||||||
class _Entry:
|
|
||||||
__slots__ = ["value"]
|
|
||||||
|
|
||||||
def __init__(self, value):
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_and_count_entires(d):
|
|
||||||
"""Takes an _Entry or dict with leaves of _Entry's, and either returns the
|
|
||||||
value or a dictionary with _Entry's replaced by their values.
|
|
||||||
|
|
||||||
Also returns the count of _Entry's
|
|
||||||
"""
|
|
||||||
if isinstance(d, dict):
|
|
||||||
cnt = 0
|
|
||||||
for key, value in d.items():
|
|
||||||
v, n = _strip_and_count_entires(value)
|
|
||||||
d[key] = v
|
|
||||||
cnt += n
|
|
||||||
return d, cnt
|
|
||||||
else:
|
|
||||||
return d.value, 1
|
|
||||||
|
|
|
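The net effect of the treecache.py rewrite above, as a short usage sketch: values now sit directly at the leaves of a chain of TreeCacheNode dicts, so popping a partial key hands back the removed subtree.

# Usage sketch of the rewritten TreeCache (keys are fixed-length tuples).
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry

cache = TreeCache()
cache[("room1", "user_a")] = 1
cache[("room1", "user_b")] = 2
cache[("room2", "user_a")] = 3

# Popping a partial key returns the removed TreeCacheNode; its leaves can be
# walked with iterate_tree_cache_entry, and the size drops by the leaf count.
subtree = cache.pop(("room1",))
assert sorted(iterate_tree_cache_entry(subtree)) == [1, 2]
assert len(cache) == 1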
@@ -17,15 +17,15 @@ import hashlib
 import unpaddedbase64
 
 
-def sha256_and_url_safe_base64(input_text):
+def sha256_and_url_safe_base64(input_text: str) -> str:
     """SHA256 hash an input string, encode the digest as url-safe base64, and
     return
 
-    :param input_text: string to hash
-    :type input_text: str
+    Args:
+        input_text: string to hash
 
-    :returns a sha256 hashed and url-safe base64 encoded digest
-    :rtype: str
+    returns:
+        A sha256 hashed and url-safe base64 encoded digest
     """
     digest = hashlib.sha256(input_text.encode()).digest()
     return unpaddedbase64.encode_base64(digest, urlsafe=True)
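A small self-contained check of what the annotated helper produces, using only the stdlib and the unpaddedbase64 dependency the module already imports:

import hashlib

import unpaddedbase64

digest = hashlib.sha256("alice@example.com".encode()).digest()
hashed = unpaddedbase64.encode_base64(digest, urlsafe=True)

# A 32-byte SHA-256 digest becomes 43 base64 characters once padding is
# stripped, using "-" and "_" instead of "+" and "/" so it is URL-safe.
assert len(hashed) == 43
assert "+" not in hashed and "/" not in hashed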
@@ -30,12 +30,12 @@ from typing import (
 T = TypeVar("T")
 
 
-def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]:
+def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
     """batch an iterable up into tuples with a maximum size
 
     Args:
-        iterable (iterable): the iterable to slice
-        size (int): the maximum batch size
+        iterable: the iterable to slice
+        size: the maximum batch size
 
     Returns:
         an iterator over the chunks
@@ -46,10 +46,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]:
     return iter(lambda: tuple(islice(sourceiter, size)), ())
 
 
-ISeq = TypeVar("ISeq", bound=Sequence, covariant=True)
-
-
-def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
+def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]:
     """Split the given sequence into chunks of the given size
 
     The last chunk may be shorter than the given size.
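Illustrative calls matching the signatures above: batch_iter yields tuples from any iterable, while chunk_seq slices a sequence and each chunk keeps the sequence's own type, which is why the looser Sequence[T] annotation suffices.

from synapse.util.iterutils import batch_iter, chunk_seq

assert list(batch_iter(range(7), 3)) == [(0, 1, 2), (3, 4, 5), (6,)]
assert list(chunk_seq("abcdefg", 3)) == ["abc", "def", "g"]
assert list(chunk_seq([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]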
@ -15,6 +15,7 @@
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import itertools
|
import itertools
|
||||||
|
from types import ModuleType
|
||||||
from typing import Any, Iterable, Tuple, Type
|
from typing import Any, Iterable, Tuple, Type
|
||||||
|
|
||||||
import jsonschema
|
import jsonschema
|
||||||
|
@ -44,8 +45,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
|
||||||
|
|
||||||
# We need to import the module, and then pick the class out of
|
# We need to import the module, and then pick the class out of
|
||||||
# that, so we split based on the last dot.
|
# that, so we split based on the last dot.
|
||||||
module, clz = modulename.rsplit(".", 1)
|
module_name, clz = modulename.rsplit(".", 1)
|
||||||
module = importlib.import_module(module)
|
module = importlib.import_module(module_name)
|
||||||
provider_class = getattr(module, clz)
|
provider_class = getattr(module, clz)
|
||||||
|
|
||||||
# Load the module config. If None, pass an empty dictionary instead
|
# Load the module config. If None, pass an empty dictionary instead
|
||||||
|
@ -69,11 +70,11 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
|
||||||
return provider_class, provider_config
|
return provider_class, provider_config
|
||||||
|
|
||||||
|
|
||||||
def load_python_module(location: str):
|
def load_python_module(location: str) -> ModuleType:
|
||||||
"""Load a python module, and return a reference to its global namespace
|
"""Load a python module, and return a reference to its global namespace
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
location (str): path to the module
|
location: path to the module
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
python module object
|
python module object
|
||||||
|
|
|
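The rename above stops `module` being rebound to two different things; the lookup itself is plain importlib, roughly as in this standalone illustration (the dotted path here is made up for the example):

import importlib

modulename = "collections.abc.Mapping"          # illustrative dotted path
module_name, clz = modulename.rsplit(".", 1)    # -> "collections.abc", "Mapping"
module = importlib.import_module(module_name)
provider_class = getattr(module, clz)
assert provider_class.__name__ == "Mapping"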
@@ -17,17 +17,17 @@ import phonenumbers
 from synapse.api.errors import SynapseError
 
 
-def phone_number_to_msisdn(country, number):
+def phone_number_to_msisdn(country: str, number: str) -> str:
     """
     Takes an ISO-3166-1 2 letter country code and phone number and
     returns an msisdn representing the canonical version of that
     phone number.
     Args:
-        country (str): ISO-3166-1 2 letter country code
-        number (str): Phone number in a national or international format
+        country: ISO-3166-1 2 letter country code
+        number: Phone number in a national or international format
 
     Returns:
-        (str) The canonical form of the phone number, as an msisdn
+        The canonical form of the phone number, as an msisdn
     Raises:
         SynapseError if the number could not be parsed.
     """
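For illustration only (the function body is not touched by this hunk): an MSISDN is the international form of the number without the leading "+", so a GB national number should come back prefixed with 44.

from synapse.util.msisdn import phone_number_to_msisdn

# 07700 900123 is an Ofcom-reserved fictional UK number; assuming E.164
# minus the "+", the canonical form would be "447700900123".
print(phone_number_to_msisdn("GB", "07700 900123"))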
@@ -82,11 +82,9 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
     retry_timings = await store.get_destination_retry_timings(destination)
 
     if retry_timings:
-        failure_ts = retry_timings["failure_ts"]
-        retry_last_ts, retry_interval = (
-            retry_timings["retry_last_ts"],
-            retry_timings["retry_interval"],
-        )
+        failure_ts = retry_timings.failure_ts
+        retry_last_ts = retry_timings.retry_last_ts
+        retry_interval = retry_timings.retry_interval
 
         now = int(clock.time_msec())
 
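get_destination_retry_timings now returns an attrs-style DestinationRetryTimings object instead of a dict (the real class lives in synapse.storage.databases.main.transactions; the storage test further down constructs it directly). A rough stand-in showing the shape of the change:

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class DestinationRetryTimings:
    # Illustrative stand-in only, not the real storage class.
    failure_ts: int
    retry_last_ts: int
    retry_interval: int


timings = DestinationRetryTimings(failure_ts=1000, retry_last_ts=50, retry_interval=100)

# Call sites switch from retry_timings["retry_interval"] to attribute access.
assert timings.retry_interval == 100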
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-import random
 import re
+import secrets
 import string
 from collections.abc import Iterable
 from typing import Optional, Tuple
@@ -35,26 +35,27 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
 #
 MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
 
-# random_string and random_string_with_symbols are used for a range of things,
-# some cryptographically important, some less so. We use SystemRandom to make sure
-# we get cryptographically-secure randoms.
-rand = random.SystemRandom()
-
 
 def random_string(length: int) -> str:
-    return "".join(rand.choice(string.ascii_letters) for _ in range(length))
+    """Generate a cryptographically secure string of random letters.
+
+    Drawn from the characters: `a-z` and `A-Z`
+    """
+    return "".join(secrets.choice(string.ascii_letters) for _ in range(length))
 
 
 def random_string_with_symbols(length: int) -> str:
-    return "".join(rand.choice(_string_with_symbols) for _ in range(length))
+    """Generate a cryptographically secure string of random letters/numbers/symbols.
+
+    Drawn from the characters: `a-z`, `A-Z`, `0-9`, and `.,;:^&*-_+=#~@`
+    """
+    return "".join(secrets.choice(_string_with_symbols) for _ in range(length))
 
 
 def is_ascii(s: bytes) -> bool:
     try:
         s.decode("ascii").encode("ascii")
-    except UnicodeDecodeError:
-        return False
-    except UnicodeEncodeError:
+    except UnicodeError:
         return False
     return True
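Equivalent behaviour to the new helpers, shown standalone: secrets.choice draws from the OS CSPRNG, much like the random.SystemRandom instance it replaces.

import secrets
import string


def random_string(length: int) -> str:
    # Same approach as the updated helper above.
    return "".join(secrets.choice(string.ascii_letters) for _ in range(length))


token = random_string(24)
assert len(token) == 24 and token.isalpha()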
synctl (102 lines changed)
|
@ -24,12 +24,13 @@ import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from synapse.config import find_config_files
|
from synapse.config import find_config_files
|
||||||
|
|
||||||
SYNAPSE = [sys.executable, "-m", "synapse.app.homeserver"]
|
MAIN_PROCESS = "synapse.app.homeserver"
|
||||||
|
|
||||||
GREEN = "\x1b[1;32m"
|
GREEN = "\x1b[1;32m"
|
||||||
YELLOW = "\x1b[1;33m"
|
YELLOW = "\x1b[1;33m"
|
||||||
|
@ -68,71 +69,37 @@ def abort(message, colour=RED, stream=sys.stderr):
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def start(configfile: str, daemonize: bool = True) -> bool:
|
def start(pidfile: str, app: str, config_files: Iterable[str], daemonize: bool) -> bool:
|
||||||
"""Attempts to start synapse.
|
"""Attempts to start a synapse main or worker process.
|
||||||
Args:
|
Args:
|
||||||
configfile: path to a yaml synapse config file
|
pidfile: the pidfile we expect the process to create
|
||||||
daemonize: whether to daemonize synapse or keep it attached to the current
|
app: the python module to run
|
||||||
session
|
config_files: config files to pass to synapse
|
||||||
|
daemonize: if True, will include a --daemonize argument to synapse
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the process started successfully
|
True if the process started successfully or was already running
|
||||||
False if there was an error starting the process
|
False if there was an error starting the process
|
||||||
|
|
||||||
If deamonize is False it will only return once synapse exits.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
write("Starting ...")
|
if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())):
|
||||||
args = SYNAPSE
|
print(app + " already running")
|
||||||
|
return True
|
||||||
|
|
||||||
|
args = [sys.executable, "-m", app]
|
||||||
|
for c in config_files:
|
||||||
|
args += ["-c", c]
|
||||||
if daemonize:
|
if daemonize:
|
||||||
args.extend(["--daemonize", "-c", configfile])
|
args.append("--daemonize")
|
||||||
else:
|
|
||||||
args.extend(["-c", configfile])
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.check_call(args)
|
subprocess.check_call(args)
|
||||||
write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
|
write("started %s(%s)" % (app, ",".join(config_files)), colour=GREEN)
|
||||||
return True
|
return True
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
write(
|
write(
|
||||||
"error starting (exit code: %d); see above for logs" % e.returncode,
|
"error starting %s(%s) (exit code: %d); see above for logs"
|
||||||
colour=RED,
|
% (app, ",".join(config_files), e.returncode),
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def start_worker(app: str, configfile: str, worker_configfile: str) -> bool:
|
|
||||||
"""Attempts to start a synapse worker.
|
|
||||||
Args:
|
|
||||||
app: name of the worker's appservice
|
|
||||||
configfile: path to a yaml synapse config file
|
|
||||||
worker_configfile: path to worker specific yaml synapse file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the process started successfully
|
|
||||||
False if there was an error starting the process
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = [
|
|
||||||
sys.executable,
|
|
||||||
"-m",
|
|
||||||
app,
|
|
||||||
"-c",
|
|
||||||
configfile,
|
|
||||||
"-c",
|
|
||||||
worker_configfile,
|
|
||||||
"--daemonize",
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
subprocess.check_call(args)
|
|
||||||
write("started %s(%r)" % (app, worker_configfile), colour=GREEN)
|
|
||||||
return True
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
write(
|
|
||||||
"error starting %s(%r) (exit code: %d); see above for logs"
|
|
||||||
% (app, worker_configfile, e.returncode),
|
|
||||||
colour=RED,
|
colour=RED,
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
@ -224,10 +191,11 @@ def main():
|
||||||
|
|
||||||
if not os.path.exists(configfile):
|
if not os.path.exists(configfile):
|
||||||
write(
|
write(
|
||||||
"No config file found\n"
|
f"Config file {configfile} does not exist.\n"
|
||||||
"To generate a config file, run '%s -c %s --generate-config"
|
f"To generate a config file, run:\n"
|
||||||
" --server-name=<server name> --report-stats=<yes/no>'\n"
|
f" {sys.executable} -m {MAIN_PROCESS}"
|
||||||
% (" ".join(SYNAPSE), options.configfile),
|
f" -c {configfile} --generate-config"
|
||||||
|
f" --server-name=<server name> --report-stats=<yes/no>\n",
|
||||||
stream=sys.stderr,
|
stream=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
@ -323,7 +291,7 @@ def main():
|
||||||
has_stopped = False
|
has_stopped = False
|
||||||
|
|
||||||
if start_stop_synapse:
|
if start_stop_synapse:
|
||||||
if not stop(pidfile, "synapse.app.homeserver"):
|
if not stop(pidfile, MAIN_PROCESS):
|
||||||
has_stopped = False
|
has_stopped = False
|
||||||
if not has_stopped and action == "stop":
|
if not has_stopped and action == "stop":
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
@ -346,30 +314,24 @@ def main():
|
||||||
if action == "start" or action == "restart":
|
if action == "start" or action == "restart":
|
||||||
error = False
|
error = False
|
||||||
if start_stop_synapse:
|
if start_stop_synapse:
|
||||||
# Check if synapse is already running
|
if not start(pidfile, MAIN_PROCESS, (configfile,), options.daemonize):
|
||||||
if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())):
|
|
||||||
abort("synapse.app.homeserver already running")
|
|
||||||
|
|
||||||
if not start(configfile, bool(options.daemonize)):
|
|
||||||
error = True
|
error = True
|
||||||
|
|
||||||
for worker in workers:
|
for worker in workers:
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
|
|
||||||
# Skip starting a worker if its already running
|
|
||||||
if os.path.exists(worker.pidfile) and pid_running(
|
|
||||||
int(open(worker.pidfile).read())
|
|
||||||
):
|
|
||||||
print(worker.app + " already running")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if worker.cache_factor:
|
if worker.cache_factor:
|
||||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
|
os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
|
||||||
|
|
||||||
for cache_name, factor in worker.cache_factors.items():
|
for cache_name, factor in worker.cache_factors.items():
|
||||||
os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
|
os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
|
||||||
|
|
||||||
if not start_worker(worker.app, configfile, worker.configfile):
|
if not start(
|
||||||
|
worker.pidfile,
|
||||||
|
worker.app,
|
||||||
|
(configfile, worker.configfile),
|
||||||
|
options.daemonize,
|
||||||
|
):
|
||||||
error = True
|
error = True
|
||||||
|
|
||||||
# Reset env back to the original
|
# Reset env back to the original
|
||||||
|
|
|
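For reference, the unified start() builds the same argument list for the main process and for workers; roughly as below (the paths are made up for the example):

import sys

app = "synapse.app.generic_worker"           # or synapse.app.homeserver
config_files = ["/etc/synapse/homeserver.yaml", "/etc/synapse/worker1.yaml"]
daemonize = True

args = [sys.executable, "-m", app]
for c in config_files:
    args += ["-c", c]
if daemonize:
    args.append("--daemonize")

# -> [python, "-m", "synapse.app.generic_worker",
#     "-c", ".../homeserver.yaml", "-c", ".../worker1.yaml", "--daemonize"]
print(args)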
@ -302,11 +302,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that the expected presence updates were sent
|
# Check that the expected presence updates were sent
|
||||||
expected_users = [
|
# We explicitly compare using sets as we expect that calling
|
||||||
|
# module_api.send_local_online_presence_to will create a presence
|
||||||
|
# update that is a duplicate of the specified user's current presence.
|
||||||
|
# These are sent to clients and will be picked up below, thus we use a
|
||||||
|
# set to deduplicate. We're just interested that non-offline updates were
|
||||||
|
# sent out for each user ID.
|
||||||
|
expected_users = {
|
||||||
self.other_user_id,
|
self.other_user_id,
|
||||||
self.presence_receiving_user_one_id,
|
self.presence_receiving_user_one_id,
|
||||||
self.presence_receiving_user_two_id,
|
self.presence_receiving_user_two_id,
|
||||||
]
|
}
|
||||||
|
found_users = set()
|
||||||
|
|
||||||
calls = (
|
calls = (
|
||||||
self.hs.get_federation_transport_client().send_transaction.call_args_list
|
self.hs.get_federation_transport_client().send_transaction.call_args_list
|
||||||
|
@ -326,12 +333,12 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
|
||||||
# EDUs can contain multiple presence updates
|
# EDUs can contain multiple presence updates
|
||||||
for presence_update in edu["content"]["push"]:
|
for presence_update in edu["content"]["push"]:
|
||||||
# Check for presence updates that contain the user IDs we're after
|
# Check for presence updates that contain the user IDs we're after
|
||||||
expected_users.remove(presence_update["user_id"])
|
found_users.add(presence_update["user_id"])
|
||||||
|
|
||||||
# Ensure that no offline states are being sent out
|
# Ensure that no offline states are being sent out
|
||||||
self.assertNotEqual(presence_update["presence"], "offline")
|
self.assertNotEqual(presence_update["presence"], "offline")
|
||||||
|
|
||||||
self.assertEqual(len(expected_users), 0)
|
self.assertEqual(found_users, expected_users)
|
||||||
|
|
||||||
|
|
||||||
def send_presence_update(
|
def send_presence_update(
|
||||||
|
|
|
@ -32,13 +32,19 @@ from synapse.handlers.presence import (
|
||||||
handle_timeout,
|
handle_timeout,
|
||||||
handle_update,
|
handle_update,
|
||||||
)
|
)
|
||||||
|
from synapse.rest import admin
|
||||||
from synapse.rest.client.v1 import room
|
from synapse.rest.client.v1 import room
|
||||||
from synapse.types import UserID, get_domain_from_id
|
from synapse.types import UserID, get_domain_from_id
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class PresenceUpdateTestCase(unittest.TestCase):
|
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
|
servlets = [admin.register_servlets]
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, homeserver):
|
||||||
|
self.store = homeserver.get_datastore()
|
||||||
|
|
||||||
def test_offline_to_online(self):
|
def test_offline_to_online(self):
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
|
@ -292,6 +298,45 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
||||||
any_order=True,
|
any_order=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_persisting_presence_updates(self):
|
||||||
|
"""Tests that the latest presence state for each user is persisted correctly"""
|
||||||
|
# Create some test users and presence states for them
|
||||||
|
presence_states = []
|
||||||
|
for i in range(5):
|
||||||
|
user_id = self.register_user(f"user_{i}", "password")
|
||||||
|
|
||||||
|
presence_state = UserPresenceState(
|
||||||
|
user_id=user_id,
|
||||||
|
state="online",
|
||||||
|
last_active_ts=1,
|
||||||
|
last_federation_update_ts=1,
|
||||||
|
last_user_sync_ts=1,
|
||||||
|
status_msg="I'm online!",
|
||||||
|
currently_active=True,
|
||||||
|
)
|
||||||
|
presence_states.append(presence_state)
|
||||||
|
|
||||||
|
# Persist these presence updates to the database
|
||||||
|
self.get_success(self.store.update_presence(presence_states))
|
||||||
|
|
||||||
|
# Check that each update is present in the database
|
||||||
|
db_presence_states = self.get_success(
|
||||||
|
self.store.get_all_presence_updates(
|
||||||
|
instance_name="master",
|
||||||
|
last_id=0,
|
||||||
|
current_id=len(presence_states) + 1,
|
||||||
|
limit=len(presence_states),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract presence update user ID and state information into lists of tuples
|
||||||
|
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
|
||||||
|
presence_states = [(ps.user_id, ps.state) for ps in presence_states]
|
||||||
|
|
||||||
|
# Compare what we put into the storage with what we got out.
|
||||||
|
# They should be identical.
|
||||||
|
self.assertEqual(presence_states, db_presence_states)
|
||||||
|
|
||||||
|
|
||||||
class PresenceTimeoutTestCase(unittest.TestCase):
|
class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
def test_idle_timer(self):
|
def test_idle_timer(self):
|
||||||
|
|
|
@ -89,14 +89,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_source = hs.get_event_sources().sources["typing"]
|
self.event_source = hs.get_event_sources().sources["typing"]
|
||||||
|
|
||||||
self.datastore = hs.get_datastore()
|
self.datastore = hs.get_datastore()
|
||||||
retry_timings_res = {
|
|
||||||
"destination": "",
|
|
||||||
"retry_last_ts": 0,
|
|
||||||
"retry_interval": 0,
|
|
||||||
"failure_ts": None,
|
|
||||||
}
|
|
||||||
self.datastore.get_destination_retry_timings = Mock(
|
self.datastore.get_destination_retry_timings = Mock(
|
||||||
return_value=defer.succeed(retry_timings_res)
|
return_value=defer.succeed(None)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore.get_device_updates_by_remote = Mock(
|
self.datastore.get_device_updates_by_remote = Mock(
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EduTypes
|
from synapse.api.constants import EduTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.federation.units import Transaction
|
from synapse.federation.units import Transaction
|
||||||
|
@ -22,11 +24,13 @@ from synapse.rest.client.v1 import login, presence, room
|
||||||
from synapse.types import create_requester
|
from synapse.types import create_requester
|
||||||
|
|
||||||
from tests.events.test_presence_router import send_presence_update, sync_presence
|
from tests.events.test_presence_router import send_presence_update, sync_presence
|
||||||
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
from tests.test_utils.event_injection import inject_member_event
|
from tests.test_utils.event_injection import inject_member_event
|
||||||
from tests.unittest import FederatingHomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||||
|
|
||||||
|
|
||||||
class ModuleApiTestCase(FederatingHomeserverTestCase):
|
class ModuleApiTestCase(HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
admin.register_servlets,
|
admin.register_servlets,
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
|
@ -217,97 +221,16 @@ class ModuleApiTestCase(FederatingHomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertFalse(is_in_public_rooms)
|
self.assertFalse(is_in_public_rooms)
|
||||||
|
|
||||||
# The ability to send federation is required by send_local_online_presence_to.
|
|
||||||
@override_config({"send_federation": True})
|
|
||||||
def test_send_local_online_presence_to(self):
|
def test_send_local_online_presence_to(self):
|
||||||
"""Tests that send_local_presence_to_users sends local online presence to local users."""
|
# Test sending local online presence to users from the main process
|
||||||
# Create a user who will send presence updates
|
_test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
|
||||||
self.presence_receiver_id = self.register_user("presence_receiver", "monkey")
|
|
||||||
self.presence_receiver_tok = self.login("presence_receiver", "monkey")
|
|
||||||
|
|
||||||
# And another user that will send presence updates out
|
|
||||||
self.presence_sender_id = self.register_user("presence_sender", "monkey")
|
|
||||||
self.presence_sender_tok = self.login("presence_sender", "monkey")
|
|
||||||
|
|
||||||
# Put them in a room together so they will receive each other's presence updates
|
|
||||||
room_id = self.helper.create_room_as(
|
|
||||||
self.presence_receiver_id,
|
|
||||||
tok=self.presence_receiver_tok,
|
|
||||||
)
|
|
||||||
self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok)
|
|
||||||
|
|
||||||
# Presence sender comes online
|
|
||||||
send_presence_update(
|
|
||||||
self,
|
|
||||||
self.presence_sender_id,
|
|
||||||
self.presence_sender_tok,
|
|
||||||
"online",
|
|
||||||
"I'm online!",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Presence receiver should have received it
|
|
||||||
presence_updates, sync_token = sync_presence(self, self.presence_receiver_id)
|
|
||||||
self.assertEqual(len(presence_updates), 1)
|
|
||||||
|
|
||||||
presence_update = presence_updates[0] # type: UserPresenceState
|
|
||||||
self.assertEqual(presence_update.user_id, self.presence_sender_id)
|
|
||||||
self.assertEqual(presence_update.state, "online")
|
|
||||||
|
|
||||||
# Syncing again should result in no presence updates
|
|
||||||
presence_updates, sync_token = sync_presence(
|
|
||||||
self, self.presence_receiver_id, sync_token
|
|
||||||
)
|
|
||||||
self.assertEqual(len(presence_updates), 0)
|
|
||||||
|
|
||||||
# Trigger sending local online presence
|
|
||||||
self.get_success(
|
|
||||||
self.module_api.send_local_online_presence_to(
|
|
||||||
[
|
|
||||||
self.presence_receiver_id,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Presence receiver should have received online presence again
|
|
||||||
presence_updates, sync_token = sync_presence(
|
|
||||||
self, self.presence_receiver_id, sync_token
|
|
||||||
)
|
|
||||||
self.assertEqual(len(presence_updates), 1)
|
|
||||||
|
|
||||||
presence_update = presence_updates[0] # type: UserPresenceState
|
|
||||||
self.assertEqual(presence_update.user_id, self.presence_sender_id)
|
|
||||||
self.assertEqual(presence_update.state, "online")
|
|
||||||
|
|
||||||
# Presence sender goes offline
|
|
||||||
send_presence_update(
|
|
||||||
self,
|
|
||||||
self.presence_sender_id,
|
|
||||||
self.presence_sender_tok,
|
|
||||||
"offline",
|
|
||||||
"I slink back into the darkness.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Trigger sending local online presence
|
|
||||||
self.get_success(
|
|
||||||
self.module_api.send_local_online_presence_to(
|
|
||||||
[
|
|
||||||
self.presence_receiver_id,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Presence receiver should *not* have received offline state
|
|
||||||
presence_updates, sync_token = sync_presence(
|
|
||||||
self, self.presence_receiver_id, sync_token
|
|
||||||
)
|
|
||||||
self.assertEqual(len(presence_updates), 0)
|
|
||||||
|
|
||||||
@override_config({"send_federation": True})
|
@override_config({"send_federation": True})
|
||||||
def test_send_local_online_presence_to_federation(self):
|
def test_send_local_online_presence_to_federation(self):
|
||||||
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
|
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
|
||||||
# Create a user who will send presence updates
|
# Create a user who will send presence updates
|
||||||
self.presence_sender_id = self.register_user("presence_sender", "monkey")
|
self.presence_sender_id = self.register_user("presence_sender1", "monkey")
|
||||||
self.presence_sender_tok = self.login("presence_sender", "monkey")
|
self.presence_sender_tok = self.login("presence_sender1", "monkey")
|
||||||
|
|
||||||
# And a room they're a part of
|
# And a room they're a part of
|
||||||
room_id = self.helper.create_room_as(
|
room_id = self.helper.create_room_as(
|
||||||
|
@ -374,3 +297,209 @@ class ModuleApiTestCase(FederatingHomeserverTestCase):
|
||||||
found_update = True
|
found_update = True
|
||||||
|
|
||||||
self.assertTrue(found_update)
|
self.assertTrue(found_update)
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
"""For testing ModuleApi functionality in a multi-worker setup"""
|
||||||
|
|
||||||
|
# Testing stream ID replication from the main to worker processes requires postgres
|
||||||
|
# (due to needing `MultiWriterIdGenerator`).
|
||||||
|
if not USE_POSTGRES_FOR_TESTS:
|
||||||
|
skip = "Requires Postgres"
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
room.register_servlets,
|
||||||
|
presence.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def default_config(self):
|
||||||
|
conf = super().default_config()
|
||||||
|
conf["redis"] = {"enabled": "true"}
|
||||||
|
conf["stream_writers"] = {"presence": ["presence_writer"]}
|
||||||
|
conf["instance_map"] = {
|
||||||
|
"presence_writer": {"host": "testserv", "port": 1001},
|
||||||
|
}
|
||||||
|
return conf
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, homeserver):
|
||||||
|
self.module_api = homeserver.get_module_api()
|
||||||
|
self.sync_handler = homeserver.get_sync_handler()
|
||||||
|
|
||||||
|
def test_send_local_online_presence_to_workers(self):
|
||||||
|
# Test sending local online presence to users from a worker process
|
||||||
|
_test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_sending_local_online_presence_to_local_user(
|
||||||
|
test_case: HomeserverTestCase, test_with_workers: bool = False
|
||||||
|
):
|
||||||
|
"""Tests that send_local_presence_to_users sends local online presence to local users.
|
||||||
|
|
||||||
|
This simultaneously tests two different usecases:
|
||||||
|
* Testing that this method works when either called from a worker or the main process.
|
||||||
|
- We test this by calling this method from both a TestCase that runs in monolith mode, and one that
|
||||||
|
runs with a main and generic_worker.
|
||||||
|
* Testing that multiple devices syncing simultaneously will all receive a snapshot of local,
|
||||||
|
online presence - but only once per device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
test_with_workers: If True, this method will call ModuleApi.send_local_online_presence_to on a
|
||||||
|
worker process. The test users will still sync with the main process. The purpose of testing
|
||||||
|
with a worker is to check whether a Synapse module running on a worker can inform other workers/
|
||||||
|
the main process that they should include additional presence when a user next syncs.
|
||||||
|
"""
|
||||||
|
if test_with_workers:
|
||||||
|
# Create a worker process to make module_api calls against
|
||||||
|
worker_hs = test_case.make_worker_hs(
|
||||||
|
"synapse.app.generic_worker", {"worker_name": "presence_writer"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a user who will send presence updates
|
||||||
|
test_case.presence_receiver_id = test_case.register_user(
|
||||||
|
"presence_receiver1", "monkey"
|
||||||
|
)
|
||||||
|
test_case.presence_receiver_tok = test_case.login("presence_receiver1", "monkey")
|
||||||
|
|
||||||
|
# And another user that will send presence updates out
|
||||||
|
test_case.presence_sender_id = test_case.register_user("presence_sender2", "monkey")
|
||||||
|
test_case.presence_sender_tok = test_case.login("presence_sender2", "monkey")
|
||||||
|
|
||||||
|
# Put them in a room together so they will receive each other's presence updates
|
||||||
|
room_id = test_case.helper.create_room_as(
|
||||||
|
test_case.presence_receiver_id,
|
||||||
|
tok=test_case.presence_receiver_tok,
|
||||||
|
)
|
||||||
|
test_case.helper.join(
|
||||||
|
room_id, test_case.presence_sender_id, tok=test_case.presence_sender_tok
|
||||||
|
)
|
||||||
|
|
||||||
|
# Presence sender comes online
|
||||||
|
send_presence_update(
|
||||||
|
test_case,
|
||||||
|
test_case.presence_sender_id,
|
||||||
|
test_case.presence_sender_tok,
|
||||||
|
"online",
|
||||||
|
"I'm online!",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Presence receiver should have received it
|
||||||
|
presence_updates, sync_token = sync_presence(
|
||||||
|
test_case, test_case.presence_receiver_id
|
||||||
|
)
|
||||||
|
test_case.assertEqual(len(presence_updates), 1)
|
||||||
|
|
||||||
|
presence_update = presence_updates[0] # type: UserPresenceState
|
||||||
|
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
|
||||||
|
test_case.assertEqual(presence_update.state, "online")
|
||||||
|
|
||||||
|
if test_with_workers:
|
||||||
|
# Replicate the current sync presence token from the main process to the worker process.
|
||||||
|
# We need to do this so that the worker process knows the current presence stream ID to
|
||||||
|
# insert into the database when we call ModuleApi.send_local_online_presence_to.
|
||||||
|
test_case.replicate()
|
||||||
|
|
||||||
|
# Syncing again should result in no presence updates
|
||||||
|
presence_updates, sync_token = sync_presence(
|
||||||
|
test_case, test_case.presence_receiver_id, sync_token
|
||||||
|
)
|
||||||
|
test_case.assertEqual(len(presence_updates), 0)
|
||||||
|
|
||||||
|
# We do an (initial) sync with a second "device" now, getting a new sync token.
|
||||||
|
# We'll use this in a moment.
|
||||||
|
_, sync_token_second_device = sync_presence(
|
||||||
|
test_case, test_case.presence_receiver_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on
|
||||||
|
if test_with_workers:
|
||||||
|
module_api_to_use = worker_hs.get_module_api()
|
||||||
|
else:
|
||||||
|
module_api_to_use = test_case.module_api
|
||||||
|
|
||||||
|
# Trigger sending local online presence. We expect this information
|
||||||
|
# to be saved to the database where all processes can access it.
|
||||||
|
# Note that we're syncing via the master.
|
||||||
|
d = module_api_to_use.send_local_online_presence_to(
|
||||||
|
[
|
||||||
|
test_case.presence_receiver_id,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
d = defer.ensureDeferred(d)
|
||||||
|
|
||||||
|
if test_with_workers:
|
||||||
|
# In order for the required presence_set_state replication request to occur between the
|
||||||
|
# worker and main process, we need to pump the reactor. Otherwise, the coordinator that
|
||||||
|
# reads the request on the main process won't do so, and the request will time out.
|
||||||
|
while not d.called:
|
||||||
|
test_case.reactor.advance(0.1)
|
||||||
|
|
||||||
|
test_case.get_success(d)
|
||||||
|
|
||||||
|
# The presence receiver should have received online presence again.
|
||||||
|
presence_updates, sync_token = sync_presence(
|
||||||
|
test_case, test_case.presence_receiver_id, sync_token
|
||||||
|
)
|
||||||
|
test_case.assertEqual(len(presence_updates), 1)
|
||||||
|
|
||||||
|
presence_update = presence_updates[0] # type: UserPresenceState
|
||||||
|
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
|
||||||
|
test_case.assertEqual(presence_update.state, "online")
|
||||||
|
|
||||||
|
# We attempt to sync with the second sync token we received above - just to check that
|
||||||
|
# multiple syncing devices will each receive the necessary online presence.
|
||||||
|
presence_updates, sync_token_second_device = sync_presence(
|
||||||
|
test_case, test_case.presence_receiver_id, sync_token_second_device
|
||||||
|
)
|
||||||
|
test_case.assertEqual(len(presence_updates), 1)
|
||||||
|
|
||||||
|
presence_update = presence_updates[0] # type: UserPresenceState
|
||||||
|
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
|
||||||
|
test_case.assertEqual(presence_update.state, "online")
|
||||||
|
|
||||||
|
# However, if we now sync with either "device", we won't receive another burst of online presence
|
||||||
|
# until the API is called again sometime in the future
|
||||||
|
presence_updates, sync_token = sync_presence(
|
||||||
|
test_case, test_case.presence_receiver_id, sync_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now we check that we don't receive *offline* updates using ModuleApi.send_local_online_presence_to.
|
||||||
|
|
||||||
|
# Presence sender goes offline
|
||||||
|
send_presence_update(
|
||||||
|
test_case,
|
||||||
|
test_case.presence_sender_id,
|
||||||
|
test_case.presence_sender_tok,
|
||||||
|
"offline",
|
||||||
|
"I slink back into the darkness.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Presence receiver should have received the updated, offline state
|
||||||
|
presence_updates, sync_token = sync_presence(
|
||||||
|
test_case, test_case.presence_receiver_id, sync_token
|
||||||
|
)
|
||||||
|
test_case.assertEqual(len(presence_updates), 1)
|
||||||
|
|
||||||
|
# Now trigger sending local online presence.
|
||||||
|
d = module_api_to_use.send_local_online_presence_to(
|
||||||
|
[
|
||||||
|
test_case.presence_receiver_id,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
d = defer.ensureDeferred(d)
|
||||||
|
|
||||||
|
if test_with_workers:
|
||||||
|
# In order for the required presence_set_state replication request to occur between the
|
||||||
|
# worker and main process, we need to pump the reactor. Otherwise, the coordinator that
|
||||||
|
# reads the request on the main process won't do so, and the request will time out.
|
||||||
|
while not d.called:
|
||||||
|
test_case.reactor.advance(0.1)
|
||||||
|
|
||||||
|
test_case.get_success(d)
|
||||||
|
|
||||||
|
# Presence receiver should *not* have received offline state
|
||||||
|
presence_updates, sync_token = sync_presence(
|
||||||
|
test_case, test_case.presence_receiver_id, sync_token
|
||||||
|
)
|
||||||
|
test_case.assertEqual(len(presence_updates), 0)
|
||||||
|
|
|
@ -30,7 +30,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
"""Checks event persisting sharding works"""
|
"""Checks event persisting sharding works"""
|
||||||
|
|
||||||
# Event persister sharding requires postgres (due to needing
|
# Event persister sharding requires postgres (due to needing
|
||||||
# `MutliWriterIdGenerator`).
|
# `MultiWriterIdGenerator`).
|
||||||
if not USE_POSTGRES_FOR_TESTS:
|
if not USE_POSTGRES_FOR_TESTS:
|
||||||
skip = "Requires Postgres"
|
skip = "Requires Postgres"
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.storage.databases.main.transactions import DestinationRetryTimings
|
||||||
from synapse.util.retryutils import MAX_RETRY_INTERVAL
|
from synapse.util.retryutils import MAX_RETRY_INTERVAL
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
@ -36,8 +37,11 @@ class TransactionStoreTestCase(HomeserverTestCase):
|
||||||
d = self.store.get_destination_retry_timings("example.com")
|
d = self.store.get_destination_retry_timings("example.com")
|
||||||
r = self.get_success(d)
|
r = self.get_success(d)
|
||||||
|
|
||||||
self.assert_dict(
|
self.assertEqual(
|
||||||
{"retry_last_ts": 50, "retry_interval": 100, "failure_ts": 1000}, r
|
DestinationRetryTimings(
|
||||||
|
retry_last_ts=50, retry_interval=100, failure_ts=1000
|
||||||
|
),
|
||||||
|
r,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_initial_set_transactions(self):
|
def test_initial_set_transactions(self):
|
||||||
|
|
|
@ -666,18 +666,20 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
with LoggingContext("c1") as c1:
|
with LoggingContext("c1") as c1:
|
||||||
obj = Cls()
|
obj = Cls()
|
||||||
obj.mock.return_value = {10: "fish", 20: "chips"}
|
obj.mock.return_value = {10: "fish", 20: "chips"}
|
||||||
|
|
||||||
|
# start the lookup off
|
||||||
d1 = obj.list_fn([10, 20], 2)
|
d1 = obj.list_fn([10, 20], 2)
|
||||||
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||||
r = yield d1
|
r = yield d1
|
||||||
self.assertEqual(current_context(), c1)
|
self.assertEqual(current_context(), c1)
|
||||||
obj.mock.assert_called_once_with([10, 20], 2)
|
obj.mock.assert_called_once_with((10, 20), 2)
|
||||||
self.assertEqual(r, {10: "fish", 20: "chips"})
|
self.assertEqual(r, {10: "fish", 20: "chips"})
|
||||||
obj.mock.reset_mock()
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
# a call with different params should call the mock again
|
# a call with different params should call the mock again
|
||||||
obj.mock.return_value = {30: "peas"}
|
obj.mock.return_value = {30: "peas"}
|
||||||
r = yield obj.list_fn([20, 30], 2)
|
r = yield obj.list_fn([20, 30], 2)
|
||||||
obj.mock.assert_called_once_with([30], 2)
|
obj.mock.assert_called_once_with((30,), 2)
|
||||||
self.assertEqual(r, {20: "chips", 30: "peas"})
|
self.assertEqual(r, {20: "chips", 30: "peas"})
|
||||||
obj.mock.reset_mock()
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
@ -692,6 +694,15 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
obj.mock.assert_not_called()
|
obj.mock.assert_not_called()
|
||||||
self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})
|
self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})
|
||||||
|
|
||||||
|
# we should also be able to use a (single-use) iterable, and should
|
||||||
|
# deduplicate the keys
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
obj.mock.return_value = {40: "gravy"}
|
||||||
|
iterable = (x for x in [10, 40, 40])
|
||||||
|
r = yield obj.list_fn(iterable, 2)
|
||||||
|
obj.mock.assert_called_once_with((40,), 2)
|
||||||
|
self.assertEqual(r, {10: "fish", 40: "gravy"})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_invalidate(self):
|
def test_invalidate(self):
|
||||||
"""Make sure that invalidation callbacks are called."""
|
"""Make sure that invalidation callbacks are called."""
|
||||||
|
@ -717,7 +728,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
# cache miss
|
# cache miss
|
||||||
obj.mock.return_value = {10: "fish", 20: "chips"}
|
obj.mock.return_value = {10: "fish", 20: "chips"}
|
||||||
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
|
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
|
||||||
obj.mock.assert_called_once_with([10, 20], 2)
|
obj.mock.assert_called_once_with((10, 20), 2)
|
||||||
self.assertEqual(r1, {10: "fish", 20: "chips"})
|
self.assertEqual(r1, {10: "fish", 20: "chips"})
|
||||||
obj.mock.reset_mock()
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
|
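The updated assertions above reflect that the batched lookup now receives a deduplicated tuple of only the keys that missed the cache, even when the caller passes a single-use iterable. The key handling in plain Python (not the real descriptor):

cached = {10: "fish", 20: "chips", 30: "peas"}   # entries already in the cache
requested = (x for x in [10, 40, 40])            # single-use iterable with a duplicate

keys = tuple(dict.fromkeys(requested))           # materialise once, dedupe, keep order
missing = tuple(k for k in keys if k not in cached)

assert missing == (40,)                          # only 40 goes to the loader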
tests/util/test_batching_queue.py (new file, 169 lines)
|
@ -0,0 +1,169 @@
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
|
from synapse.util.batching_queue import BatchingQueue
|
||||||
|
|
||||||
|
from tests.server import get_clock
|
||||||
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class BatchingQueueTestCase(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.clock, hs_clock = get_clock()
|
||||||
|
|
||||||
|
self._pending_calls = []
|
||||||
|
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
|
||||||
|
|
||||||
|
async def _process_queue(self, values):
|
||||||
|
d = defer.Deferred()
|
||||||
|
self._pending_calls.append((values, d))
|
||||||
|
return await make_deferred_yieldable(d)
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
"""Tests the basic case of calling `add_to_queue` once and having
|
||||||
|
`_process_queue` return.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertFalse(self._pending_calls)
|
||||||
|
|
||||||
|
queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))
|
||||||
|
|
||||||
|
# The queue should wait a reactor tick before calling the processing
|
||||||
|
# function.
|
||||||
|
self.assertFalse(self._pending_calls)
|
||||||
|
self.assertFalse(queue_d.called)
|
||||||
|
|
||||||
|
# We should see a call to `_process_queue` after a reactor tick.
|
||||||
|
self.clock.pump([0])
|
||||||
|
|
||||||
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo"])
|
||||||
|
self.assertFalse(queue_d.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back.
|
||||||
|
self._pending_calls.pop()[1].callback("bar")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d), "bar")
|
||||||
|
|
||||||
|
def test_batching(self):
|
||||||
|
"""Test that multiple calls at the same time get batched up into one
|
||||||
|
call to `_process_queue`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertFalse(self._pending_calls)
|
||||||
|
|
||||||
|
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
||||||
|
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
||||||
|
|
||||||
|
self.clock.pump([0])
|
||||||
|
|
||||||
|
# We should see only *one* call to `_process_queue`
|
||||||
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
|
||||||
|
self.assertFalse(queue_d1.called)
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to both.
|
||||||
|
self._pending_calls.pop()[1].callback("bar")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d1), "bar")
|
||||||
|
self.assertEqual(self.successResultOf(queue_d2), "bar")
|
||||||
|
|
||||||
|
def test_queuing(self):
|
||||||
|
"""Test that we queue up requests while a `_process_queue` is being
|
||||||
|
called.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertFalse(self._pending_calls)
|
||||||
|
|
||||||
|
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
||||||
|
self.clock.pump([0])
|
||||||
|
|
||||||
|
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
||||||
|
|
||||||
|
# We should see only *one* call to `_process_queue`
|
||||||
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo1"])
|
||||||
|
self.assertFalse(queue_d1.called)
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
|
# first.
|
||||||
|
self._pending_calls.pop()[1].callback("bar1")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
|
||||||
|
# We should now see a second call to `_process_queue`
|
||||||
|
self.clock.pump([0])
|
||||||
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo2"])
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
|
# second.
|
||||||
|
self._pending_calls.pop()[1].callback("bar2")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d2), "bar2")
|
||||||
|
|
||||||
|
def test_different_keys(self):
|
||||||
|
"""Test that calls to different keys get processed in parallel."""
|
||||||
|
|
||||||
|
self.assertFalse(self._pending_calls)
|
||||||
|
|
||||||
|
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1))
|
||||||
|
self.clock.pump([0])
|
||||||
|
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2))
|
||||||
|
self.clock.pump([0])
|
||||||
|
|
||||||
|
# We queue up another item with key=2 to check that we will keep taking
|
||||||
|
# things off the queue.
|
||||||
|
queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3", key=2))
|
||||||
|
|
||||||
|
# We should see two calls to `_process_queue`
|
||||||
|
self.assertEqual(len(self._pending_calls), 2)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo1"])
|
||||||
|
self.assertEqual(self._pending_calls[1][0], ["foo2"])
|
||||||
|
self.assertFalse(queue_d1.called)
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
self.assertFalse(queue_d3.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
|
# first.
|
||||||
|
self._pending_calls.pop(0)[1].callback("bar1")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
self.assertFalse(queue_d3.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
|
# second.
|
||||||
|
self._pending_calls.pop()[1].callback("bar2")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d2), "bar2")
|
||||||
|
self.assertFalse(queue_d3.called)
|
||||||
|
|
||||||
|
# We should now see a call `_pending_calls` for `foo3`
|
||||||
|
self.clock.pump([0])
|
||||||
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo3"])
|
||||||
|
self.assertFalse(queue_d3.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
|
# third deferred.
|
||||||
|
self._pending_calls.pop()[1].callback("bar4")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d3), "bar4")
|
|
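The new test file above exercises synapse.util.batching_queue.BatchingQueue. As the tests show, the processing callback receives every value queued since the previous batch, calls sharing a key are batched together, and different keys run in parallel. A usage sketch (names other than BatchingQueue and add_to_queue are assumed for illustration):

from synapse.util.batching_queue import BatchingQueue


async def process_batch(values):
    # One bulk operation for the whole batch, e.g. a single DB write.
    return f"processed {len(values)} values"


def make_queue(clock):
    # `clock` would be the homeserver Clock (the tests use tests.server.get_clock()).
    return BatchingQueue("example_queue", clock, process_batch)

# Callers then `await queue.add_to_queue(value, key=some_key)`; the return value
# of process_batch is propagated back to every caller in that batch.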
@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
from typing import Dict, Iterable, List, Sequence

from synapse.util.iterutils import chunk_seq, sorted_topologically

@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase):
        )

    def test_empty_input(self):
        parts = chunk_seq([], 5)
        parts = chunk_seq([], 5)  # type: Iterable[Sequence]

        self.assertEqual(
            list(parts),
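The added type comment only annotates what `chunk_seq` returns for an empty input. As a reminder of the behaviour being annotated, a small hedged illustration, assuming (as the surrounding tests do) that `chunk_seq` yields slices of the input no longer than the requested size:

    from synapse.util.iterutils import chunk_seq

    # Each chunk is a slice of the input of at most the requested length.
    assert list(chunk_seq("abcdefgh", 3)) == ["abc", "def", "gh"]

    # An empty input yields no chunks at all, hence Iterable[Sequence].
    assert list(chunk_seq([], 5)) == []
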
@ -59,7 +59,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
        self.assertEquals(cache.pop("key"), None)

    def test_del_multi(self):
        cache = LruCache(4, keylen=2, cache_type=TreeCache)
        cache = LruCache(4, cache_type=TreeCache)
        cache[("animal", "cat")] = "mew"
        cache[("animal", "dog")] = "woof"
        cache[("vehicles", "car")] = "vroom"
@ -165,7 +165,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
        m2 = Mock()
        m3 = Mock()
        m4 = Mock()
        cache = LruCache(4, keylen=2, cache_type=TreeCache)
        cache = LruCache(4, cache_type=TreeCache)

        cache.set(("a", "1"), "value", callbacks=[m1])
        cache.set(("a", "2"), "value", callbacks=[m2])
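Both `LruCache` hunks make the same change: with a `TreeCache` backing store, the key length no longer has to be declared up front via `keylen`. A minimal sketch of the updated construction, using only calls exercised elsewhere in these test files:

    from synapse.util.caches.lrucache import LruCache
    from synapse.util.caches.treecache import TreeCache

    # No keylen argument; tuple keys work against the TreeCache directly.
    cache = LruCache(4, cache_type=TreeCache)
    cache[("animal", "cat")] = "mew"
    cache[("vehicles", "car")] = "vroom"

    assert cache.get(("animal", "cat")) == "mew"
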
@ -51,10 +51,12 @@ class RetryLimiterTestCase(HomeserverTestCase):
        except AssertionError:
            pass

        self.pump()

        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
        self.assertEqual(new_timings["failure_ts"], failure_ts)
        self.assertEqual(new_timings.failure_ts, failure_ts)
        self.assertEqual(new_timings["retry_last_ts"], failure_ts)
        self.assertEqual(new_timings.retry_last_ts, failure_ts)
        self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
        self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)

        # now if we try again we should get a failure
        self.get_failure(
@ -77,14 +79,16 @@ class RetryLimiterTestCase(HomeserverTestCase):
        except AssertionError:
            pass

        self.pump()

        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
        self.assertEqual(new_timings["failure_ts"], failure_ts)
        self.assertEqual(new_timings.failure_ts, failure_ts)
        self.assertEqual(new_timings["retry_last_ts"], retry_ts)
        self.assertEqual(new_timings.retry_last_ts, retry_ts)
        self.assertGreaterEqual(
            new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
            new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
        )
        self.assertLessEqual(
            new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
            new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
        )

        #
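The assertion changes in these two retry-limiter hunks follow from `get_destination_retry_timings` now returning a structured object rather than a dict: the fields keep their names but are read as attributes. A hedged sketch of a call site after the change (the store and destination here are placeholders):

    async def backoff_deadline(store, destination):
        # Returns None when the destination has never failed.
        timings = await store.get_destination_retry_timings(destination)
        if timings is None:
            return None

        # Attribute access replaces the old dict-style lookups.
        return timings.retry_last_ts + timings.retry_interval
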
@ -13,7 +13,7 @@
# limitations under the License.


from synapse.util.caches.treecache import TreeCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry

from .. import unittest

@ -64,12 +64,14 @@ class TreeCacheTestCase(unittest.TestCase):
        cache[("a", "b")] = "AB"
        cache[("b", "a")] = "BA"
        self.assertEquals(cache.get(("a", "a")), "AA")
        cache.pop(("a",))
        popped = cache.pop(("a",))
        self.assertEquals(cache.get(("a", "a")), None)
        self.assertEquals(cache.get(("a", "b")), None)
        self.assertEquals(cache.get(("b", "a")), "BA")
        self.assertEquals(len(cache), 1)

        self.assertEquals({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))

    def test_clear(self):
        cache = TreeCache()
        cache[("a",)] = "A"
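The new assertion depends on `TreeCache.pop` returning the removed subtree when given a key prefix, with `iterate_tree_cache_entry` flattening that subtree into its leaf values. A short illustration of the behaviour the test pins down (the key names are invented):

    from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry

    cache = TreeCache()
    cache[("user", "alice")] = 1
    cache[("user", "bob")] = 2
    cache[("room", "lobby")] = 3

    # Popping a one-element prefix removes everything stored under it and
    # returns the subtree node rather than a single value.
    removed = cache.pop(("user",))

    assert set(iterate_tree_cache_entry(removed)) == {1, 2}
    assert cache.get(("room", "lobby")) == 3
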
tox.ini (10 lines changed)
@ -34,7 +34,17 @@ lint_targets =
    synapse
    tests
    scripts
    # annoyingly, black doesn't find these so we have to list them
    scripts/export_signing_key
    scripts/generate_config
    scripts/generate_log_config
    scripts/hash_password
    scripts/register_new_matrix_user
    scripts/synapse_port_db
    scripts-dev
    scripts-dev/build_debian_packages
    scripts-dev/sign_json
    scripts-dev/update_database
    stubs
    contrib
    synctl