mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-04-24 21:29:16 -04:00
Merge remote-tracking branch 'upstream/release-v1.35'
This commit is contained in:
commit
4740b83c39
@ -3,7 +3,7 @@
|
||||
# CI's Docker setup at the point where this file is considered.
|
||||
server_name: "localhost:8800"
|
||||
|
||||
signing_key_path: "/src/.buildkite/test.signing.key"
|
||||
signing_key_path: ".buildkite/test.signing.key"
|
||||
|
||||
report_stats: false
|
||||
|
||||
@ -16,6 +16,4 @@ database:
|
||||
database: synapse
|
||||
|
||||
# Suppress the key server warning.
|
||||
trusted_key_servers:
|
||||
- server_name: "matrix.org"
|
||||
suppress_key_server_warning: true
|
||||
trusted_key_servers: []
|
||||
|
@ -33,6 +33,10 @@ scripts-dev/update_database --database-config .buildkite/sqlite-config.yaml
|
||||
echo "+++ Run synapse_port_db against test database"
|
||||
coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml
|
||||
|
||||
# We should be able to run twice against the same database.
|
||||
echo "+++ Run synapse_port_db a second time"
|
||||
coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml
|
||||
|
||||
#####
|
||||
|
||||
# Now do the same again, on an empty database.
|
||||
|
@ -3,7 +3,7 @@
|
||||
# schema and run background updates on it.
|
||||
server_name: "localhost:8800"
|
||||
|
||||
signing_key_path: "/src/.buildkite/test.signing.key"
|
||||
signing_key_path: ".buildkite/test.signing.key"
|
||||
|
||||
report_stats: false
|
||||
|
||||
@ -13,6 +13,4 @@ database:
|
||||
database: ".buildkite/test_db.db"
|
||||
|
||||
# Suppress the key server warning.
|
||||
trusted_key_servers:
|
||||
- server_name: "matrix.org"
|
||||
suppress_key_server_warning: true
|
||||
trusted_key_servers: []
|
||||
|
64
CHANGES.md
64
CHANGES.md
@ -1,3 +1,67 @@
|
||||
Synapse 1.35.0rc1 (2021-05-25)
|
||||
==============================
|
||||
|
||||
Features
|
||||
--------
|
||||
|
||||
- Add experimental support to allow a user who could join a restricted room to view it in the spaces summary. ([\#9922](https://github.com/matrix-org/synapse/issues/9922), [\#10007](https://github.com/matrix-org/synapse/issues/10007), [\#10038](https://github.com/matrix-org/synapse/issues/10038))
|
||||
- Reduce memory usage when joining very large rooms over federation. ([\#9958](https://github.com/matrix-org/synapse/issues/9958))
|
||||
- Add a configuration option which allows enabling opentracing by user id. ([\#9978](https://github.com/matrix-org/synapse/issues/9978))
|
||||
- Enable experimental support for [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946) (spaces summary API) and [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) (restricted join rules) by default. ([\#10011](https://github.com/matrix-org/synapse/issues/10011))
|
||||
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
|
||||
- Fix a bug introduced in v1.26.0 which meant that `synapse_port_db` would not correctly initialise some postgres sequences, requiring manual updates afterwards. ([\#9991](https://github.com/matrix-org/synapse/issues/9991))
|
||||
- Fix `synctl`'s `--no-daemonize` parameter to work correctly with worker processes. ([\#9995](https://github.com/matrix-org/synapse/issues/9995))
|
||||
- Fix a validation bug introduced in v1.34.0 in the ordering of spaces in the space summary API. ([\#10002](https://github.com/matrix-org/synapse/issues/10002))
|
||||
- Fixed deletion of new presence stream states from database. ([\#10014](https://github.com/matrix-org/synapse/issues/10014), [\#10033](https://github.com/matrix-org/synapse/issues/10033))
|
||||
- Fixed a bug with very high resolution image uploads throwing internal server errors. ([\#10029](https://github.com/matrix-org/synapse/issues/10029))
|
||||
|
||||
|
||||
Updates to the Docker image
|
||||
---------------------------
|
||||
|
||||
- Fix bug introduced in Synapse 1.33.0 which caused a `Permission denied: '/homeserver.log'` error when starting Synapse with the generated log configuration. Contributed by Sergio Miguéns Iglesias. ([\#10045](https://github.com/matrix-org/synapse/issues/10045))
|
||||
|
||||
|
||||
Improved Documentation
|
||||
----------------------
|
||||
|
||||
- Add hardened systemd files as proposed in [#9760](https://github.com/matrix-org/synapse/issues/9760) and added them to `contrib/`. Change the docs to reflect the presence of these files. ([\#9803](https://github.com/matrix-org/synapse/issues/9803))
|
||||
- Clarify documentation around SSO mapping providers generating unique IDs and localparts. ([\#9980](https://github.com/matrix-org/synapse/issues/9980))
|
||||
- Updates to the PostgreSQL documentation (`postgres.md`). ([\#9988](https://github.com/matrix-org/synapse/issues/9988), [\#9989](https://github.com/matrix-org/synapse/issues/9989))
|
||||
- Fix broken link in user directory documentation. Contributed by @junquera. ([\#10016](https://github.com/matrix-org/synapse/issues/10016))
|
||||
- Add missing room state entry to the table of contents of room admin API. ([\#10043](https://github.com/matrix-org/synapse/issues/10043))
|
||||
|
||||
|
||||
Deprecations and Removals
|
||||
-------------------------
|
||||
|
||||
- Removed support for the deprecated `tls_fingerprints` configuration setting. Contributed by Jerin J Titus. ([\#9280](https://github.com/matrix-org/synapse/issues/9280))
|
||||
|
||||
|
||||
Internal Changes
|
||||
----------------
|
||||
|
||||
- Allow sending full presence to users via workers other than the one that called `ModuleApi.send_local_online_presence_to`. ([\#9823](https://github.com/matrix-org/synapse/issues/9823))
|
||||
- Update comments in the space summary handler. ([\#9974](https://github.com/matrix-org/synapse/issues/9974))
|
||||
- Minor enhancements to the `@cachedList` descriptor. ([\#9975](https://github.com/matrix-org/synapse/issues/9975))
|
||||
- Split multipart email sending into a dedicated handler. ([\#9977](https://github.com/matrix-org/synapse/issues/9977))
|
||||
- Run `black` on files in the `scripts` directory. ([\#9981](https://github.com/matrix-org/synapse/issues/9981))
|
||||
- Add missing type hints to `synapse.util` module. ([\#9982](https://github.com/matrix-org/synapse/issues/9982))
|
||||
- Simplify a few helper functions. ([\#9984](https://github.com/matrix-org/synapse/issues/9984), [\#9985](https://github.com/matrix-org/synapse/issues/9985), [\#9986](https://github.com/matrix-org/synapse/issues/9986))
|
||||
- Remove unnecessary property from SQLBaseStore. ([\#9987](https://github.com/matrix-org/synapse/issues/9987))
|
||||
- Remove `keylen` param on `LruCache`. ([\#9993](https://github.com/matrix-org/synapse/issues/9993))
|
||||
- Update the Grafana dashboard in `contrib/`. ([\#10001](https://github.com/matrix-org/synapse/issues/10001))
|
||||
- Add a batching queue implementation. ([\#10017](https://github.com/matrix-org/synapse/issues/10017))
|
||||
- Reduce memory usage when verifying signatures on large numbers of events at once. ([\#10018](https://github.com/matrix-org/synapse/issues/10018))
|
||||
- Properly invalidate caches for destination retry timings every (instead of expiring entries every 5 minutes). ([\#10036](https://github.com/matrix-org/synapse/issues/10036))
|
||||
- Fix running complement tests with Synapse workers. ([\#10039](https://github.com/matrix-org/synapse/issues/10039))
|
||||
- Fix typo in `get_state_ids_for_event` docstring where the return type was incorrect. ([\#10050](https://github.com/matrix-org/synapse/issues/10050))
|
||||
|
||||
|
||||
Synapse 1.34.0 (2021-05-17)
|
||||
===========================
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
71
contrib/systemd/override-hardened.conf
Normal file
71
contrib/systemd/override-hardened.conf
Normal file
@ -0,0 +1,71 @@
|
||||
[Service]
|
||||
# The following directives give the synapse service R/W access to:
|
||||
# - /run/matrix-synapse
|
||||
# - /var/lib/matrix-synapse
|
||||
# - /var/log/matrix-synapse
|
||||
|
||||
RuntimeDirectory=matrix-synapse
|
||||
StateDirectory=matrix-synapse
|
||||
LogsDirectory=matrix-synapse
|
||||
|
||||
######################
|
||||
## Security Sandbox ##
|
||||
######################
|
||||
|
||||
# Make sure that the service has its own unshared tmpfs at /tmp and that it
|
||||
# cannot see or change any real devices
|
||||
PrivateTmp=true
|
||||
PrivateDevices=true
|
||||
|
||||
# We give no capabilities to a service by default
|
||||
CapabilityBoundingSet=
|
||||
AmbientCapabilities=
|
||||
|
||||
# Protect the following from modification:
|
||||
# - The entire filesystem
|
||||
# - sysctl settings and loaded kernel modules
|
||||
# - No modifications allowed to Control Groups
|
||||
# - Hostname
|
||||
# - System Clock
|
||||
ProtectSystem=strict
|
||||
ProtectKernelTunables=true
|
||||
ProtectKernelModules=true
|
||||
ProtectControlGroups=true
|
||||
ProtectClock=true
|
||||
ProtectHostname=true
|
||||
|
||||
# Prevent access to the following:
|
||||
# - /home directory
|
||||
# - Kernel logs
|
||||
ProtectHome=tmpfs
|
||||
ProtectKernelLogs=true
|
||||
|
||||
# Make sure that the process can only see PIDs and process details of itself,
|
||||
# and the second option disables seeing details of things like system load and
|
||||
# I/O etc
|
||||
ProtectProc=invisible
|
||||
ProcSubset=pid
|
||||
|
||||
# While not needed, we set these options explicitly
|
||||
# - This process has been given access to the host network
|
||||
# - It can also communicate with any IP Address
|
||||
PrivateNetwork=false
|
||||
RestrictAddressFamilies=AF_INET AF_INET6 AF_UNIX
|
||||
IPAddressAllow=any
|
||||
|
||||
# Restrict system calls to a sane bunch
|
||||
SystemCallArchitectures=native
|
||||
SystemCallFilter=@system-service
|
||||
SystemCallFilter=~@privileged @resources @obsolete
|
||||
|
||||
# Misc restrictions
|
||||
# - Since the process is a python process it needs to be able to write and
|
||||
# execute memory regions, so we set MemoryDenyWriteExecute to false
|
||||
RestrictSUIDSGID=true
|
||||
RemoveIPC=true
|
||||
NoNewPrivileges=true
|
||||
RestrictRealtime=true
|
||||
RestrictNamespaces=true
|
||||
LockPersonality=true
|
||||
PrivateUsers=true
|
||||
MemoryDenyWriteExecute=false
|
@ -9,10 +9,11 @@ formatters:
|
||||
{% endif %}
|
||||
|
||||
handlers:
|
||||
{% if LOG_FILE_PATH %}
|
||||
file:
|
||||
class: logging.handlers.TimedRotatingFileHandler
|
||||
formatter: precise
|
||||
filename: {{ LOG_FILE_PATH or "homeserver.log" }}
|
||||
filename: {{ LOG_FILE_PATH }}
|
||||
when: "midnight"
|
||||
backupCount: 6 # Does not include the current log file.
|
||||
encoding: utf8
|
||||
@ -29,6 +30,7 @@ handlers:
|
||||
# be written to disk.
|
||||
capacity: 10
|
||||
flushLevel: 30 # Flush for WARNING logs as well
|
||||
{% endif %}
|
||||
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
|
@ -184,18 +184,18 @@ stderr_logfile_maxbytes=0
|
||||
"""
|
||||
|
||||
NGINX_LOCATION_CONFIG_BLOCK = """
|
||||
location ~* {endpoint} {
|
||||
location ~* {endpoint} {{
|
||||
proxy_pass {upstream};
|
||||
proxy_set_header X-Forwarded-For $remote_addr;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header Host $host;
|
||||
}
|
||||
}}
|
||||
"""
|
||||
|
||||
NGINX_UPSTREAM_CONFIG_BLOCK = """
|
||||
upstream {upstream_worker_type} {
|
||||
upstream {upstream_worker_type} {{
|
||||
{body}
|
||||
}
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
* [Usage](#usage)
|
||||
- [Room Details API](#room-details-api)
|
||||
- [Room Members API](#room-members-api)
|
||||
- [Room State API](#room-state-api)
|
||||
- [Delete Room API](#delete-room-api)
|
||||
* [Parameters](#parameters-1)
|
||||
* [Response](#response)
|
||||
|
@ -42,17 +42,17 @@ To receive OpenTracing spans, start up a Jaeger server. This can be done
|
||||
using docker like so:
|
||||
|
||||
```sh
|
||||
docker run -d --name jaeger
|
||||
docker run -d --name jaeger \
|
||||
-p 6831:6831/udp \
|
||||
-p 6832:6832/udp \
|
||||
-p 5778:5778 \
|
||||
-p 16686:16686 \
|
||||
-p 14268:14268 \
|
||||
jaegertracing/all-in-one:1.13
|
||||
jaegertracing/all-in-one:1
|
||||
```
|
||||
|
||||
Latest documentation is probably at
|
||||
<https://www.jaegertracing.io/docs/1.13/getting-started/>
|
||||
https://www.jaegertracing.io/docs/latest/getting-started.
|
||||
|
||||
## Enable OpenTracing in Synapse
|
||||
|
||||
@ -62,7 +62,7 @@ as shown in the [sample config](./sample_config.yaml). For example:
|
||||
|
||||
```yaml
|
||||
opentracing:
|
||||
tracer_enabled: true
|
||||
enabled: true
|
||||
homeserver_whitelist:
|
||||
- "mytrustedhomeserver.org"
|
||||
- "*.myotherhomeservers.com"
|
||||
@ -90,4 +90,4 @@ to two problems, namely:
|
||||
## Configuring Jaeger
|
||||
|
||||
Sampling strategies can be set as in this document:
|
||||
<https://www.jaegertracing.io/docs/1.13/sampling/>
|
||||
<https://www.jaegertracing.io/docs/latest/sampling/>.
|
||||
|
200
docs/postgres.md
200
docs/postgres.md
@ -1,6 +1,6 @@
|
||||
# Using Postgres
|
||||
|
||||
Postgres version 9.5 or later is known to work.
|
||||
Synapse supports PostgreSQL versions 9.6 or later.
|
||||
|
||||
## Install postgres client libraries
|
||||
|
||||
@ -33,28 +33,15 @@ Assuming your PostgreSQL database user is called `postgres`, first authenticate
|
||||
# Or, if your system uses sudo to get administrative rights
|
||||
sudo -u postgres bash
|
||||
|
||||
Then, create a user ``synapse_user`` with:
|
||||
Then, create a postgres user and a database with:
|
||||
|
||||
# this will prompt for a password for the new user
|
||||
createuser --pwprompt synapse_user
|
||||
|
||||
Before you can authenticate with the `synapse_user`, you must create a
|
||||
database that it can access. To create a database, first connect to the
|
||||
database with your database user:
|
||||
createdb --encoding=UTF8 --locale=C --template=template0 --owner=synapse_user synapse
|
||||
|
||||
su - postgres # Or: sudo -u postgres bash
|
||||
psql
|
||||
|
||||
and then run:
|
||||
|
||||
CREATE DATABASE synapse
|
||||
ENCODING 'UTF8'
|
||||
LC_COLLATE='C'
|
||||
LC_CTYPE='C'
|
||||
template=template0
|
||||
OWNER synapse_user;
|
||||
|
||||
This would create an appropriate database named `synapse` owned by the
|
||||
`synapse_user` user (which must already have been created as above).
|
||||
The above will create a user called `synapse_user`, and a database called
|
||||
`synapse`.
|
||||
|
||||
Note that the PostgreSQL database *must* have the correct encoding set
|
||||
(as shown above), otherwise it will not be able to store UTF8 strings.
|
||||
@ -63,79 +50,6 @@ You may need to enable password authentication so `synapse_user` can
|
||||
connect to the database. See
|
||||
<https://www.postgresql.org/docs/current/auth-pg-hba-conf.html>.
|
||||
|
||||
If you get an error along the lines of `FATAL: Ident authentication failed for
|
||||
user "synapse_user"`, you may need to use an authentication method other than
|
||||
`ident`:
|
||||
|
||||
* If the `synapse_user` user has a password, add the password to the `database:`
|
||||
section of `homeserver.yaml`. Then add the following to `pg_hba.conf`:
|
||||
|
||||
```
|
||||
host synapse synapse_user ::1/128 md5 # or `scram-sha-256` instead of `md5` if you use that
|
||||
```
|
||||
|
||||
* If the `synapse_user` user does not have a password, then a password doesn't
|
||||
have to be added to `homeserver.yaml`. But the following does need to be added
|
||||
to `pg_hba.conf`:
|
||||
|
||||
```
|
||||
host synapse synapse_user ::1/128 trust
|
||||
```
|
||||
|
||||
Note that line order matters in `pg_hba.conf`, so make sure that if you do add a
|
||||
new line, it is inserted before:
|
||||
|
||||
```
|
||||
host all all ::1/128 ident
|
||||
```
|
||||
|
||||
### Fixing incorrect `COLLATE` or `CTYPE`
|
||||
|
||||
Synapse will refuse to set up a new database if it has the wrong values of
|
||||
`COLLATE` and `CTYPE` set, and will log warnings on existing databases. Using
|
||||
different locales can cause issues if the locale library is updated from
|
||||
underneath the database, or if a different version of the locale is used on any
|
||||
replicas.
|
||||
|
||||
The safest way to fix the issue is to take a dump and recreate the database with
|
||||
the correct `COLLATE` and `CTYPE` parameters (as shown above). It is also possible to change the
|
||||
parameters on a live database and run a `REINDEX` on the entire database,
|
||||
however extreme care must be taken to avoid database corruption.
|
||||
|
||||
Note that the above may fail with an error about duplicate rows if corruption
|
||||
has already occurred, and such duplicate rows will need to be manually removed.
|
||||
|
||||
|
||||
## Fixing inconsistent sequences error
|
||||
|
||||
Synapse uses Postgres sequences to generate IDs for various tables. A sequence
|
||||
and associated table can get out of sync if, for example, Synapse has been
|
||||
downgraded and then upgraded again.
|
||||
|
||||
To fix the issue shut down Synapse (including any and all workers) and run the
|
||||
SQL command included in the error message. Once done Synapse should start
|
||||
successfully.
|
||||
|
||||
|
||||
## Tuning Postgres
|
||||
|
||||
The default settings should be fine for most deployments. For larger
|
||||
scale deployments tuning some of the settings is recommended, details of
|
||||
which can be found at
|
||||
<https://wiki.postgresql.org/wiki/Tuning_Your_PostgreSQL_Server>.
|
||||
|
||||
In particular, we've found tuning the following values helpful for
|
||||
performance:
|
||||
|
||||
- `shared_buffers`
|
||||
- `effective_cache_size`
|
||||
- `work_mem`
|
||||
- `maintenance_work_mem`
|
||||
- `autovacuum_work_mem`
|
||||
|
||||
Note that the appropriate values for those fields depend on the amount
|
||||
of free memory the database host has available.
|
||||
|
||||
## Synapse config
|
||||
|
||||
When you are ready to start using PostgreSQL, edit the `database`
|
||||
@ -165,18 +79,42 @@ may block for an extended period while it waits for a response from the
|
||||
database server. Example values might be:
|
||||
|
||||
```yaml
|
||||
# seconds of inactivity after which TCP should send a keepalive message to the server
|
||||
keepalives_idle: 10
|
||||
database:
|
||||
args:
|
||||
# ... as above
|
||||
|
||||
# the number of seconds after which a TCP keepalive message that is not
|
||||
# acknowledged by the server should be retransmitted
|
||||
keepalives_interval: 10
|
||||
# seconds of inactivity after which TCP should send a keepalive message to the server
|
||||
keepalives_idle: 10
|
||||
|
||||
# the number of TCP keepalives that can be lost before the client's connection
|
||||
# to the server is considered dead
|
||||
keepalives_count: 3
|
||||
# the number of seconds after which a TCP keepalive message that is not
|
||||
# acknowledged by the server should be retransmitted
|
||||
keepalives_interval: 10
|
||||
|
||||
# the number of TCP keepalives that can be lost before the client's connection
|
||||
# to the server is considered dead
|
||||
keepalives_count: 3
|
||||
```
|
||||
|
||||
## Tuning Postgres
|
||||
|
||||
The default settings should be fine for most deployments. For larger
|
||||
scale deployments tuning some of the settings is recommended, details of
|
||||
which can be found at
|
||||
<https://wiki.postgresql.org/wiki/Tuning_Your_PostgreSQL_Server>.
|
||||
|
||||
In particular, we've found tuning the following values helpful for
|
||||
performance:
|
||||
|
||||
- `shared_buffers`
|
||||
- `effective_cache_size`
|
||||
- `work_mem`
|
||||
- `maintenance_work_mem`
|
||||
- `autovacuum_work_mem`
|
||||
|
||||
Note that the appropriate values for those fields depend on the amount
|
||||
of free memory the database host has available.
|
||||
|
||||
|
||||
## Porting from SQLite
|
||||
|
||||
### Overview
|
||||
@ -185,9 +123,8 @@ The script `synapse_port_db` allows porting an existing synapse server
|
||||
backed by SQLite to using PostgreSQL. This is done in as a two phase
|
||||
process:
|
||||
|
||||
1. Copy the existing SQLite database to a separate location (while the
|
||||
server is down) and running the port script against that offline
|
||||
database.
|
||||
1. Copy the existing SQLite database to a separate location and run
|
||||
the port script against that offline database.
|
||||
2. Shut down the server. Rerun the port script to port any data that
|
||||
has come in since taking the first snapshot. Restart server against
|
||||
the PostgreSQL database.
|
||||
@ -245,3 +182,60 @@ PostgreSQL database configuration file `homeserver-postgres.yaml`:
|
||||
./synctl start
|
||||
|
||||
Synapse should now be running against PostgreSQL.
|
||||
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Alternative auth methods
|
||||
|
||||
If you get an error along the lines of `FATAL: Ident authentication failed for
|
||||
user "synapse_user"`, you may need to use an authentication method other than
|
||||
`ident`:
|
||||
|
||||
* If the `synapse_user` user has a password, add the password to the `database:`
|
||||
section of `homeserver.yaml`. Then add the following to `pg_hba.conf`:
|
||||
|
||||
```
|
||||
host synapse synapse_user ::1/128 md5 # or `scram-sha-256` instead of `md5` if you use that
|
||||
```
|
||||
|
||||
* If the `synapse_user` user does not have a password, then a password doesn't
|
||||
have to be added to `homeserver.yaml`. But the following does need to be added
|
||||
to `pg_hba.conf`:
|
||||
|
||||
```
|
||||
host synapse synapse_user ::1/128 trust
|
||||
```
|
||||
|
||||
Note that line order matters in `pg_hba.conf`, so make sure that if you do add a
|
||||
new line, it is inserted before:
|
||||
|
||||
```
|
||||
host all all ::1/128 ident
|
||||
```
|
||||
|
||||
### Fixing incorrect `COLLATE` or `CTYPE`
|
||||
|
||||
Synapse will refuse to set up a new database if it has the wrong values of
|
||||
`COLLATE` and `CTYPE` set, and will log warnings on existing databases. Using
|
||||
different locales can cause issues if the locale library is updated from
|
||||
underneath the database, or if a different version of the locale is used on any
|
||||
replicas.
|
||||
|
||||
The safest way to fix the issue is to dump the database and recreate it with
|
||||
the correct locale parameter (as shown above). It is also possible to change the
|
||||
parameters on a live database and run a `REINDEX` on the entire database,
|
||||
however extreme care must be taken to avoid database corruption.
|
||||
|
||||
Note that the above may fail with an error about duplicate rows if corruption
|
||||
has already occurred, and such duplicate rows will need to be manually removed.
|
||||
|
||||
### Fixing inconsistent sequences error
|
||||
|
||||
Synapse uses Postgres sequences to generate IDs for various tables. A sequence
|
||||
and associated table can get out of sync if, for example, Synapse has been
|
||||
downgraded and then upgraded again.
|
||||
|
||||
To fix the issue shut down Synapse (including any and all workers) and run the
|
||||
SQL command included in the error message. Once done Synapse should start
|
||||
successfully.
|
||||
|
@ -28,7 +28,11 @@ async def ModuleApi.send_local_online_presence_to(users: Iterable[str]) -> None
|
||||
which can be given a list of local or remote MXIDs to broadcast known, online user
|
||||
presence to (for those users that the receiving user is considered interested in).
|
||||
It does not include state for users who are currently offline, and it can only be
|
||||
called on workers that support sending federation.
|
||||
called on workers that support sending federation. Additionally, this method must
|
||||
only be called from the process that has been configured to write to the
|
||||
the [presence stream](https://github.com/matrix-org/synapse/blob/master/docs/workers.md#stream-writers).
|
||||
By default, this is the main process, but another worker can be configured to do
|
||||
so.
|
||||
|
||||
### Module structure
|
||||
|
||||
|
@ -683,33 +683,6 @@ acme:
|
||||
#
|
||||
account_key_file: DATADIR/acme_account.key
|
||||
|
||||
# List of allowed TLS fingerprints for this server to publish along
|
||||
# with the signing keys for this server. Other matrix servers that
|
||||
# make HTTPS requests to this server will check that the TLS
|
||||
# certificates returned by this server match one of the fingerprints.
|
||||
#
|
||||
# Synapse automatically adds the fingerprint of its own certificate
|
||||
# to the list. So if federation traffic is handled directly by synapse
|
||||
# then no modification to the list is required.
|
||||
#
|
||||
# If synapse is run behind a load balancer that handles the TLS then it
|
||||
# will be necessary to add the fingerprints of the certificates used by
|
||||
# the loadbalancers to this list if they are different to the one
|
||||
# synapse is using.
|
||||
#
|
||||
# Homeservers are permitted to cache the list of TLS fingerprints
|
||||
# returned in the key responses up to the "valid_until_ts" returned in
|
||||
# key. It may be necessary to publish the fingerprints of a new
|
||||
# certificate and wait until the "valid_until_ts" of the previous key
|
||||
# responses have passed before deploying it.
|
||||
#
|
||||
# You can calculate a fingerprint from a given TLS listener via:
|
||||
# openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
|
||||
# openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
|
||||
# or by checking matrix.org/federationtester/api/report?server_name=$host
|
||||
#
|
||||
#tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
||||
|
||||
|
||||
## Federation ##
|
||||
|
||||
@ -2845,7 +2818,8 @@ opentracing:
|
||||
#enabled: true
|
||||
|
||||
# The list of homeservers we wish to send and receive span contexts and span baggage.
|
||||
# See docs/opentracing.rst
|
||||
# See docs/opentracing.rst.
|
||||
#
|
||||
# This is a list of regexes which are matched against the server_name of the
|
||||
# homeserver.
|
||||
#
|
||||
@ -2854,19 +2828,26 @@ opentracing:
|
||||
#homeserver_whitelist:
|
||||
# - ".*"
|
||||
|
||||
# A list of the matrix IDs of users whose requests will always be traced,
|
||||
# even if the tracing system would otherwise drop the traces due to
|
||||
# probabilistic sampling.
|
||||
#
|
||||
# By default, the list is empty.
|
||||
#
|
||||
#force_tracing_for_users:
|
||||
# - "@user1:server_name"
|
||||
# - "@user2:server_name"
|
||||
|
||||
# Jaeger can be configured to sample traces at different rates.
|
||||
# All configuration options provided by Jaeger can be set here.
|
||||
# Jaeger's configuration mostly related to trace sampling which
|
||||
# Jaeger's configuration is mostly related to trace sampling which
|
||||
# is documented here:
|
||||
# https://www.jaegertracing.io/docs/1.13/sampling/.
|
||||
# https://www.jaegertracing.io/docs/latest/sampling/.
|
||||
#
|
||||
#jaeger_config:
|
||||
# sampler:
|
||||
# type: const
|
||||
# param: 1
|
||||
|
||||
# Logging whether spans were started and reported
|
||||
#
|
||||
# logging:
|
||||
# false
|
||||
|
||||
@ -2935,3 +2916,18 @@ redis:
|
||||
# Optional password if configured on the Redis instance
|
||||
#
|
||||
#password: <secret_password>
|
||||
|
||||
|
||||
# Enable experimental features in Synapse.
|
||||
#
|
||||
# Experimental features might break or be removed without a deprecation
|
||||
# period.
|
||||
#
|
||||
experimental_features:
|
||||
# Support for Spaces (MSC1772), it enables the following:
|
||||
#
|
||||
# * The Spaces Summary API (MSC2946).
|
||||
# * Restricting room membership based on space membership (MSC3083).
|
||||
#
|
||||
# Uncomment to disable support for Spaces.
|
||||
#spaces_enabled: false
|
||||
|
@ -67,8 +67,8 @@ A custom mapping provider must specify the following methods:
|
||||
- Arguments:
|
||||
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
|
||||
information from.
|
||||
- This method must return a string, which is the unique identifier for the
|
||||
user. Commonly the ``sub`` claim of the response.
|
||||
- This method must return a string, which is the unique, immutable identifier
|
||||
for the user. Commonly the `sub` claim of the response.
|
||||
* `map_user_attributes(self, userinfo, token, failures)`
|
||||
- This method must be async.
|
||||
- Arguments:
|
||||
@ -87,7 +87,9 @@ A custom mapping provider must specify the following methods:
|
||||
`localpart` value, such as `john.doe1`.
|
||||
- Returns a dictionary with two keys:
|
||||
- `localpart`: A string, used to generate the Matrix ID. If this is
|
||||
`None`, the user is prompted to pick their own username.
|
||||
`None`, the user is prompted to pick their own username. This is only used
|
||||
during a user's first login. Once a localpart has been associated with a
|
||||
remote user ID (see `get_remote_user_id`) it cannot be updated.
|
||||
- `displayname`: An optional string, the display name for the user.
|
||||
* `get_extra_attributes(self, userinfo, token)`
|
||||
- This method must be async.
|
||||
@ -153,8 +155,8 @@ A custom mapping provider must specify the following methods:
|
||||
information from.
|
||||
- `client_redirect_url` - A string, the URL that the client will be
|
||||
redirected to.
|
||||
- This method must return a string, which is the unique identifier for the
|
||||
user. Commonly the ``uid`` claim of the response.
|
||||
- This method must return a string, which is the unique, immutable identifier
|
||||
for the user. Commonly the `uid` claim of the response.
|
||||
* `saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
|
||||
- Arguments:
|
||||
- `saml_response` - A `saml2.response.AuthnResponse` object to extract user
|
||||
@ -172,8 +174,10 @@ A custom mapping provider must specify the following methods:
|
||||
redirected to.
|
||||
- This method must return a dictionary, which will then be used by Synapse
|
||||
to build a new user. The following keys are allowed:
|
||||
* `mxid_localpart` - The mxid localpart of the new user. If this is
|
||||
`None`, the user is prompted to pick their own username.
|
||||
* `mxid_localpart` - A string, the mxid localpart of the new user. If this is
|
||||
`None`, the user is prompted to pick their own username. This is only used
|
||||
during a user's first login. Once a localpart has been associated with a
|
||||
remote user ID (see `get_remote_user_id`) it cannot be updated.
|
||||
* `displayname` - The displayname of the new user. If not provided, will default to
|
||||
the value of `mxid_localpart`.
|
||||
* `emails` - A list of emails for the new user. If not provided, will
|
||||
|
@ -65,3 +65,33 @@ systemctl restart matrix-synapse-worker@federation_reader.service
|
||||
systemctl enable matrix-synapse-worker@federation_writer.service
|
||||
systemctl restart matrix-synapse.target
|
||||
```
|
||||
|
||||
## Hardening
|
||||
|
||||
**Optional:** If further hardening is desired, the file
|
||||
`override-hardened.conf` may be copied from
|
||||
`contrib/systemd/override-hardened.conf` in this repository to the location
|
||||
`/etc/systemd/system/matrix-synapse.service.d/override-hardened.conf` (the
|
||||
directory may have to be created). It enables certain sandboxing features in
|
||||
systemd to further secure the synapse service. You may read the comments to
|
||||
understand what the override file is doing. The same file will need to be copied
|
||||
to
|
||||
`/etc/systemd/system/matrix-synapse-worker@.service.d/override-hardened-worker.conf`
|
||||
(this directory may also have to be created) in order to apply the same
|
||||
hardening options to any worker processes.
|
||||
|
||||
Once these files have been copied to their appropriate locations, simply reload
|
||||
systemd's manager config files and restart all Synapse services to apply the hardening options. They will automatically
|
||||
be applied at every restart as long as the override files are present at the
|
||||
specified locations.
|
||||
|
||||
```sh
|
||||
systemctl daemon-reload
|
||||
|
||||
# Restart services
|
||||
systemctl restart matrix-synapse.target
|
||||
```
|
||||
|
||||
In order to see their effect, you may run `systemd-analyze security
|
||||
matrix-synapse.service` before and after applying the hardening options to see
|
||||
the changes being applied at a glance.
|
||||
|
@ -7,6 +7,6 @@ who are present in a publicly viewable room present on the server.
|
||||
|
||||
The directory info is stored in various tables, which can (typically after
|
||||
DB corruption) get stale or out of sync. If this happens, for now the
|
||||
solution to fix it is to execute the SQL [here](../synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql)
|
||||
solution to fix it is to execute the SQL [here](https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/main/delta/53/user_dir_populate.sql)
|
||||
and then restart synapse. This should then start a background task to
|
||||
flush the current tables and regenerate the directory.
|
||||
|
12
mypy.ini
12
mypy.ini
@ -71,8 +71,13 @@ files =
|
||||
synapse/types.py,
|
||||
synapse/util/async_helpers.py,
|
||||
synapse/util/caches,
|
||||
synapse/util/daemonize.py,
|
||||
synapse/util/hash.py,
|
||||
synapse/util/iterutils.py,
|
||||
synapse/util/metrics.py,
|
||||
synapse/util/macaroons.py,
|
||||
synapse/util/module_loader.py,
|
||||
synapse/util/msisdn.py,
|
||||
synapse/util/stringutils.py,
|
||||
synapse/visibility.py,
|
||||
tests/replication,
|
||||
@ -80,6 +85,7 @@ files =
|
||||
tests/handlers/test_password_providers.py,
|
||||
tests/rest/client/v1/test_login.py,
|
||||
tests/rest/client/v2_alpha/test_auth.py,
|
||||
tests/util/test_itertools.py,
|
||||
tests/util/test_stream_change_cache.py
|
||||
|
||||
[mypy-pymacaroons.*]
|
||||
@ -174,3 +180,9 @@ ignore_missing_imports = True
|
||||
|
||||
[mypy-pympler.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-phonenumbers.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-ijson.*]
|
||||
ignore_missing_imports = True
|
||||
|
@ -21,18 +21,18 @@ DISTS = (
|
||||
"debian:buster",
|
||||
"debian:bullseye",
|
||||
"debian:sid",
|
||||
"ubuntu:bionic", # 18.04 LTS (our EOL forced by Py36 on 2021-12-23)
|
||||
"ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14)
|
||||
"ubuntu:groovy", # 20.10 (EOL 2021-07-07)
|
||||
"ubuntu:bionic", # 18.04 LTS (our EOL forced by Py36 on 2021-12-23)
|
||||
"ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14)
|
||||
"ubuntu:groovy", # 20.10 (EOL 2021-07-07)
|
||||
"ubuntu:hirsute", # 21.04 (EOL 2022-01-05)
|
||||
)
|
||||
|
||||
DESC = '''\
|
||||
DESC = """\
|
||||
Builds .debs for synapse, using a Docker image for the build environment.
|
||||
|
||||
By default, builds for all known distributions, but a list of distributions
|
||||
can be passed on the commandline for debugging.
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
class Builder(object):
|
||||
@ -46,7 +46,7 @@ class Builder(object):
|
||||
"""Build deb for a single distribution"""
|
||||
|
||||
if self._failed:
|
||||
print("not building %s due to earlier failure" % (dist, ))
|
||||
print("not building %s due to earlier failure" % (dist,))
|
||||
raise Exception("failed")
|
||||
|
||||
try:
|
||||
@ -68,48 +68,65 @@ class Builder(object):
|
||||
# we tend to get source packages which are full of debs. (We could hack
|
||||
# around that with more magic in the build_debian.sh script, but that
|
||||
# doesn't solve the problem for natively-run dpkg-buildpakage).
|
||||
debsdir = os.path.join(projdir, '../debs')
|
||||
debsdir = os.path.join(projdir, "../debs")
|
||||
os.makedirs(debsdir, exist_ok=True)
|
||||
|
||||
if self.redirect_stdout:
|
||||
logfile = os.path.join(debsdir, "%s.buildlog" % (tag, ))
|
||||
logfile = os.path.join(debsdir, "%s.buildlog" % (tag,))
|
||||
print("building %s: directing output to %s" % (dist, logfile))
|
||||
stdout = open(logfile, "w")
|
||||
else:
|
||||
stdout = None
|
||||
|
||||
# first build a docker image for the build environment
|
||||
subprocess.check_call([
|
||||
"docker", "build",
|
||||
"--tag", "dh-venv-builder:" + tag,
|
||||
"--build-arg", "distro=" + dist,
|
||||
"-f", "docker/Dockerfile-dhvirtualenv",
|
||||
"docker",
|
||||
], stdout=stdout, stderr=subprocess.STDOUT)
|
||||
subprocess.check_call(
|
||||
[
|
||||
"docker",
|
||||
"build",
|
||||
"--tag",
|
||||
"dh-venv-builder:" + tag,
|
||||
"--build-arg",
|
||||
"distro=" + dist,
|
||||
"-f",
|
||||
"docker/Dockerfile-dhvirtualenv",
|
||||
"docker",
|
||||
],
|
||||
stdout=stdout,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
|
||||
container_name = "synapse_build_" + tag
|
||||
with self._lock:
|
||||
self.active_containers.add(container_name)
|
||||
|
||||
# then run the build itself
|
||||
subprocess.check_call([
|
||||
"docker", "run",
|
||||
"--rm",
|
||||
"--name", container_name,
|
||||
"--volume=" + projdir + ":/synapse/source:ro",
|
||||
"--volume=" + debsdir + ":/debs",
|
||||
"-e", "TARGET_USERID=%i" % (os.getuid(), ),
|
||||
"-e", "TARGET_GROUPID=%i" % (os.getgid(), ),
|
||||
"-e", "DEB_BUILD_OPTIONS=%s" % ("nocheck" if skip_tests else ""),
|
||||
"dh-venv-builder:" + tag,
|
||||
], stdout=stdout, stderr=subprocess.STDOUT)
|
||||
subprocess.check_call(
|
||||
[
|
||||
"docker",
|
||||
"run",
|
||||
"--rm",
|
||||
"--name",
|
||||
container_name,
|
||||
"--volume=" + projdir + ":/synapse/source:ro",
|
||||
"--volume=" + debsdir + ":/debs",
|
||||
"-e",
|
||||
"TARGET_USERID=%i" % (os.getuid(),),
|
||||
"-e",
|
||||
"TARGET_GROUPID=%i" % (os.getgid(),),
|
||||
"-e",
|
||||
"DEB_BUILD_OPTIONS=%s" % ("nocheck" if skip_tests else ""),
|
||||
"dh-venv-builder:" + tag,
|
||||
],
|
||||
stdout=stdout,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self.active_containers.remove(container_name)
|
||||
|
||||
if stdout is not None:
|
||||
stdout.close()
|
||||
print("Completed build of %s" % (dist, ))
|
||||
print("Completed build of %s" % (dist,))
|
||||
|
||||
def kill_containers(self):
|
||||
with self._lock:
|
||||
@ -117,9 +134,14 @@ class Builder(object):
|
||||
|
||||
for c in active:
|
||||
print("killing container %s" % (c,))
|
||||
subprocess.run([
|
||||
"docker", "kill", c,
|
||||
], stdout=subprocess.DEVNULL)
|
||||
subprocess.run(
|
||||
[
|
||||
"docker",
|
||||
"kill",
|
||||
c,
|
||||
],
|
||||
stdout=subprocess.DEVNULL,
|
||||
)
|
||||
with self._lock:
|
||||
self.active_containers.remove(c)
|
||||
|
||||
@ -130,31 +152,38 @@ def run_builds(dists, jobs=1, skip_tests=False):
|
||||
def sig(signum, _frame):
|
||||
print("Caught SIGINT")
|
||||
builder.kill_containers()
|
||||
|
||||
signal.signal(signal.SIGINT, sig)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=jobs) as e:
|
||||
res = e.map(lambda dist: builder.run_build(dist, skip_tests), dists)
|
||||
|
||||
# make sure we consume the iterable so that exceptions are raised.
|
||||
for r in res:
|
||||
for _ in res:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description=DESC,
|
||||
)
|
||||
parser.add_argument(
|
||||
'-j', '--jobs', type=int, default=1,
|
||||
help='specify the number of builds to run in parallel',
|
||||
"-j",
|
||||
"--jobs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="specify the number of builds to run in parallel",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-check', action='store_true',
|
||||
help='skip running tests after building',
|
||||
"--no-check",
|
||||
action="store_true",
|
||||
help="skip running tests after building",
|
||||
)
|
||||
parser.add_argument(
|
||||
'dist', nargs='*', default=DISTS,
|
||||
help='a list of distributions to build for. Default: %(default)s',
|
||||
"dist",
|
||||
nargs="*",
|
||||
default=DISTS,
|
||||
help="a list of distributions to build for. Default: %(default)s",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
run_builds(dists=args.dist, jobs=args.jobs, skip_tests=args.no_check)
|
||||
|
@ -10,6 +10,9 @@
|
||||
# checkout by setting the COMPLEMENT_DIR environment variable to the
|
||||
# filepath of a local Complement checkout.
|
||||
#
|
||||
# By default Synapse is run in monolith mode. This can be overridden by
|
||||
# setting the WORKERS environment variable.
|
||||
#
|
||||
# A regular expression of test method names can be supplied as the first
|
||||
# argument to the script. Complement will then only run those tests. If
|
||||
# no regex is supplied, all tests are run. For example;
|
||||
@ -32,10 +35,26 @@ if [[ -z "$COMPLEMENT_DIR" ]]; then
|
||||
echo "Checkout available at 'complement-master'"
|
||||
fi
|
||||
|
||||
# If we're using workers, modify the docker files slightly.
|
||||
if [[ -n "$WORKERS" ]]; then
|
||||
BASE_IMAGE=matrixdotorg/synapse-workers
|
||||
BASE_DOCKERFILE=docker/Dockerfile-workers
|
||||
export COMPLEMENT_BASE_IMAGE=complement-synapse-workers
|
||||
COMPLEMENT_DOCKERFILE=SynapseWorkers.Dockerfile
|
||||
# And provide some more configuration to complement.
|
||||
export COMPLEMENT_CA=true
|
||||
export COMPLEMENT_VERSION_CHECK_ITERATIONS=500
|
||||
else
|
||||
BASE_IMAGE=matrixdotorg/synapse
|
||||
BASE_DOCKERFILE=docker/Dockerfile
|
||||
export COMPLEMENT_BASE_IMAGE=complement-synapse
|
||||
COMPLEMENT_DOCKERFILE=Synapse.Dockerfile
|
||||
fi
|
||||
|
||||
# Build the base Synapse image from the local checkout
|
||||
docker build -t matrixdotorg/synapse -f docker/Dockerfile .
|
||||
docker build -t $BASE_IMAGE -f "$BASE_DOCKERFILE" .
|
||||
# Build the Synapse monolith image from Complement, based on the above image we just built
|
||||
docker build -t complement-synapse -f "$COMPLEMENT_DIR/dockerfiles/Synapse.Dockerfile" "$COMPLEMENT_DIR/dockerfiles"
|
||||
docker build -t $COMPLEMENT_BASE_IMAGE -f "$COMPLEMENT_DIR/dockerfiles/$COMPLEMENT_DOCKERFILE" "$COMPLEMENT_DIR/dockerfiles"
|
||||
|
||||
cd "$COMPLEMENT_DIR"
|
||||
|
||||
@ -46,4 +65,4 @@ if [[ -n "$1" ]]; then
|
||||
fi
|
||||
|
||||
# Run the tests!
|
||||
COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -tags synapse_blacklist,msc2946,msc3083 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests
|
||||
go test -v -tags synapse_blacklist,msc2946,msc3083 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests
|
||||
|
@ -1,4 +1,3 @@
|
||||
import hashlib
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
@ -54,15 +53,9 @@ def convert_v1_to_v2(server_name, valid_until, keys, certificate):
|
||||
"server_name": server_name,
|
||||
"verify_keys": {key_id: {"key": key} for key_id, key in keys.items()},
|
||||
"valid_until_ts": valid_until,
|
||||
"tls_fingerprints": [fingerprint(certificate)],
|
||||
}
|
||||
|
||||
|
||||
def fingerprint(certificate):
|
||||
finger = hashlib.sha256(certificate)
|
||||
return {"sha256": encode_base64(finger.digest())}
|
||||
|
||||
|
||||
def rows_v2(server, json):
|
||||
valid_until = json["valid_until_ts"]
|
||||
key_json = encode_canonical_json(json)
|
||||
|
@ -80,8 +80,22 @@ else
|
||||
# then lint everything!
|
||||
if [[ -z ${files+x} ]]; then
|
||||
# Lint all source code files and directories
|
||||
# Note: this list aims the mirror the one in tox.ini
|
||||
files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite")
|
||||
# Note: this list aims to mirror the one in tox.ini
|
||||
files=(
|
||||
"synapse" "docker" "tests"
|
||||
# annoyingly, black doesn't find these so we have to list them
|
||||
"scripts/export_signing_key"
|
||||
"scripts/generate_config"
|
||||
"scripts/generate_log_config"
|
||||
"scripts/hash_password"
|
||||
"scripts/register_new_matrix_user"
|
||||
"scripts/synapse_port_db"
|
||||
"scripts-dev"
|
||||
"scripts-dev/build_debian_packages"
|
||||
"scripts-dev/sign_json"
|
||||
"scripts-dev/update_database"
|
||||
"contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite"
|
||||
)
|
||||
fi
|
||||
fi
|
||||
|
||||
|
@ -30,7 +30,11 @@ def exit(status: int = 0, message: Optional[str] = None):
|
||||
def format_plain(public_key: nacl.signing.VerifyKey):
|
||||
print(
|
||||
"%s:%s %s"
|
||||
% (public_key.alg, public_key.version, encode_verify_key_base64(public_key),)
|
||||
% (
|
||||
public_key.alg,
|
||||
public_key.version,
|
||||
encode_verify_key_base64(public_key),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -50,7 +54,10 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"key_file", nargs="+", type=argparse.FileType("r"), help="The key file to read",
|
||||
"key_file",
|
||||
nargs="+",
|
||||
type=argparse.FileType("r"),
|
||||
help="The key file to read",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -63,7 +70,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--expiry-ts",
|
||||
type=int,
|
||||
default=int(time.time() * 1000) + 6*3600000,
|
||||
default=int(time.time() * 1000) + 6 * 3600000,
|
||||
help=(
|
||||
"The expiry time to use for -x, in milliseconds since 1970. The default "
|
||||
"is (now+6h)."
|
||||
|
@ -11,23 +11,22 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--config-dir",
|
||||
default="CONFDIR",
|
||||
|
||||
help="The path where the config files are kept. Used to create filenames for "
|
||||
"things like the log config and the signing key. Default: %(default)s",
|
||||
"things like the log config and the signing key. Default: %(default)s",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
default="DATADIR",
|
||||
help="The path where the data files are kept. Used to create filenames for "
|
||||
"things like the database and media store. Default: %(default)s",
|
||||
"things like the database and media store. Default: %(default)s",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--server-name",
|
||||
default="SERVERNAME",
|
||||
help="The server name. Used to initialise the server_name config param, but also "
|
||||
"used in the names of some of the config files. Default: %(default)s",
|
||||
"used in the names of some of the config files. Default: %(default)s",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -41,21 +40,22 @@ if __name__ == "__main__":
|
||||
"--generate-secrets",
|
||||
action="store_true",
|
||||
help="Enable generation of new secrets for things like the macaroon_secret_key."
|
||||
"By default, these parameters will be left unset."
|
||||
"By default, these parameters will be left unset.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o", "--output-file",
|
||||
type=argparse.FileType('w'),
|
||||
"-o",
|
||||
"--output-file",
|
||||
type=argparse.FileType("w"),
|
||||
default=sys.stdout,
|
||||
help="File to write the configuration to. Default: stdout",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--header-file",
|
||||
type=argparse.FileType('r'),
|
||||
type=argparse.FileType("r"),
|
||||
help="File from which to read a header, which will be printed before the "
|
||||
"generated config.",
|
||||
"generated config.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -41,7 +41,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--config",
|
||||
type=argparse.FileType('r'),
|
||||
type=argparse.FileType("r"),
|
||||
help=(
|
||||
"Path to server config file. "
|
||||
"Used to read in bcrypt_rounds and password_pepper."
|
||||
@ -72,8 +72,8 @@ if __name__ == "__main__":
|
||||
pw = unicodedata.normalize("NFKC", password)
|
||||
|
||||
hashed = bcrypt.hashpw(
|
||||
pw.encode('utf8') + password_pepper.encode("utf8"),
|
||||
pw.encode("utf8") + password_pepper.encode("utf8"),
|
||||
bcrypt.gensalt(bcrypt_rounds),
|
||||
).decode('ascii')
|
||||
).decode("ascii")
|
||||
|
||||
print(hashed)
|
||||
|
@ -294,8 +294,7 @@ class Porter(object):
|
||||
return table, already_ported, total_to_port, forward_chunk, backward_chunk
|
||||
|
||||
async def get_table_constraints(self) -> Dict[str, Set[str]]:
|
||||
"""Returns a map of tables that have foreign key constraints to tables they depend on.
|
||||
"""
|
||||
"""Returns a map of tables that have foreign key constraints to tables they depend on."""
|
||||
|
||||
def _get_constraints(txn):
|
||||
# We can pull the information about foreign key constraints out from
|
||||
@ -504,7 +503,9 @@ class Porter(object):
|
||||
return
|
||||
|
||||
def build_db_store(
|
||||
self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False,
|
||||
self,
|
||||
db_config: DatabaseConnectionConfig,
|
||||
allow_outdated_version: bool = False,
|
||||
):
|
||||
"""Builds and returns a database store using the provided configuration.
|
||||
|
||||
@ -740,7 +741,7 @@ class Porter(object):
|
||||
return col
|
||||
|
||||
outrows = []
|
||||
for i, row in enumerate(rows):
|
||||
for row in rows:
|
||||
try:
|
||||
outrows.append(
|
||||
tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
|
||||
@ -890,8 +891,7 @@ class Porter(object):
|
||||
await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
|
||||
|
||||
async def _setup_events_stream_seqs(self) -> None:
|
||||
"""Set the event stream sequences to the correct values.
|
||||
"""
|
||||
"""Set the event stream sequences to the correct values."""
|
||||
|
||||
# We get called before we've ported the events table, so we need to
|
||||
# fetch the current positions from the SQLite store.
|
||||
@ -920,12 +920,14 @@ class Porter(object):
|
||||
)
|
||||
|
||||
await self.postgres_store.db_pool.runInteraction(
|
||||
"_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
|
||||
"_setup_events_stream_seqs",
|
||||
_setup_events_stream_seqs_set_pos,
|
||||
)
|
||||
|
||||
async def _setup_sequence(self, sequence_name: str, stream_id_tables: Iterable[str]) -> None:
|
||||
"""Set a sequence to the correct value.
|
||||
"""
|
||||
async def _setup_sequence(
|
||||
self, sequence_name: str, stream_id_tables: Iterable[str]
|
||||
) -> None:
|
||||
"""Set a sequence to the correct value."""
|
||||
current_stream_ids = []
|
||||
for stream_id_table in stream_id_tables:
|
||||
max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||
@ -939,20 +941,25 @@ class Porter(object):
|
||||
next_id = max(current_stream_ids) + 1
|
||||
|
||||
def r(txn):
|
||||
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name, )
|
||||
txn.execute(sql + " %s", (next_id, ))
|
||||
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
|
||||
txn.execute(sql + " %s", (next_id,))
|
||||
|
||||
await self.postgres_store.db_pool.runInteraction("_setup_%s" % (sequence_name,), r)
|
||||
await self.postgres_store.db_pool.runInteraction(
|
||||
"_setup_%s" % (sequence_name,), r
|
||||
)
|
||||
|
||||
async def _setup_auth_chain_sequence(self) -> None:
|
||||
curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||
table="event_auth_chains", keyvalues={}, retcol="MAX(chain_id)", allow_none=True
|
||||
table="event_auth_chains",
|
||||
keyvalues={},
|
||||
retcol="MAX(chain_id)",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
def r(txn):
|
||||
txn.execute(
|
||||
"ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
|
||||
(curr_chain_id,),
|
||||
(curr_chain_id + 1,),
|
||||
)
|
||||
|
||||
if curr_chain_id is not None:
|
||||
@ -968,8 +975,7 @@ class Porter(object):
|
||||
|
||||
|
||||
class Progress(object):
|
||||
"""Used to report progress of the port
|
||||
"""
|
||||
"""Used to report progress of the port"""
|
||||
|
||||
def __init__(self):
|
||||
self.tables = {}
|
||||
@ -994,8 +1000,7 @@ class Progress(object):
|
||||
|
||||
|
||||
class CursesProgress(Progress):
|
||||
"""Reports progress to a curses window
|
||||
"""
|
||||
"""Reports progress to a curses window"""
|
||||
|
||||
def __init__(self, stdscr):
|
||||
self.stdscr = stdscr
|
||||
@ -1020,7 +1025,7 @@ class CursesProgress(Progress):
|
||||
|
||||
self.total_processed = 0
|
||||
self.total_remaining = 0
|
||||
for table, data in self.tables.items():
|
||||
for data in self.tables.values():
|
||||
self.total_processed += data["num_done"] - data["start"]
|
||||
self.total_remaining += data["total"] - data["num_done"]
|
||||
|
||||
@ -1111,8 +1116,7 @@ class CursesProgress(Progress):
|
||||
|
||||
|
||||
class TerminalProgress(Progress):
|
||||
"""Just prints progress to the terminal
|
||||
"""
|
||||
"""Just prints progress to the terminal"""
|
||||
|
||||
def update(self, table, num_done):
|
||||
super(TerminalProgress, self).update(table, num_done)
|
||||
|
@ -47,7 +47,7 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__version__ = "1.34.0"
|
||||
__version__ = "1.35.0rc1"
|
||||
|
||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||
# We import here so that we don't have to install a bunch of deps when
|
||||
|
@ -87,6 +87,7 @@ class Auth:
|
||||
)
|
||||
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
|
||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
|
||||
|
||||
async def check_from_context(
|
||||
self, room_version: str, event, context, do_sig_check=True
|
||||
@ -208,6 +209,8 @@ class Auth:
|
||||
opentracing.set_tag("authenticated_entity", user_id)
|
||||
opentracing.set_tag("user_id", user_id)
|
||||
opentracing.set_tag("appservice_id", app_service.id)
|
||||
if user_id in self._force_tracing_for_users:
|
||||
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
|
||||
|
||||
return requester
|
||||
|
||||
@ -260,6 +263,8 @@ class Auth:
|
||||
opentracing.set_tag("user_id", user_info.user_id)
|
||||
if device_id:
|
||||
opentracing.set_tag("device_id", device_id)
|
||||
if user_info.token_owner in self._force_tracing_for_users:
|
||||
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
|
||||
|
||||
return requester
|
||||
except KeyError:
|
||||
|
@ -61,7 +61,6 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
from synapse.rest.admin import register_servlets_for_media_repo
|
||||
from synapse.rest.client.v1 import events, login, presence, room
|
||||
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
|
||||
@ -237,7 +236,6 @@ class GenericWorkerSlavedStore(
|
||||
DirectoryStore,
|
||||
SlavedApplicationServiceStore,
|
||||
SlavedRegistrationStore,
|
||||
SlavedTransactionStore,
|
||||
SlavedProfileStore,
|
||||
SlavedClientIpStore,
|
||||
SlavedFilteringStore,
|
||||
|
@ -29,9 +29,26 @@ class ExperimentalConfig(Config):
|
||||
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
|
||||
|
||||
# Spaces (MSC1772, MSC2946, MSC3083, etc)
|
||||
self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
|
||||
self.spaces_enabled = experimental.get("spaces_enabled", True) # type: bool
|
||||
if self.spaces_enabled:
|
||||
KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083
|
||||
|
||||
# MSC3026 (busy presence state)
|
||||
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
|
||||
|
||||
def generate_config_section(self, **kwargs):
|
||||
return """\
|
||||
# Enable experimental features in Synapse.
|
||||
#
|
||||
# Experimental features might break or be removed without a deprecation
|
||||
# period.
|
||||
#
|
||||
experimental_features:
|
||||
# Support for Spaces (MSC1772), it enables the following:
|
||||
#
|
||||
# * The Spaces Summary API (MSC2946).
|
||||
# * Restricting room membership based on space membership (MSC3083).
|
||||
#
|
||||
# Uncomment to disable support for Spaces.
|
||||
#spaces_enabled: false
|
||||
"""
|
||||
|
@ -59,7 +59,6 @@ class HomeServerConfig(RootConfig):
|
||||
config_classes = [
|
||||
MeowConfig,
|
||||
ServerConfig,
|
||||
ExperimentalConfig,
|
||||
TlsConfig,
|
||||
FederationConfig,
|
||||
CacheConfig,
|
||||
@ -96,4 +95,5 @@ class HomeServerConfig(RootConfig):
|
||||
TracerConfig,
|
||||
WorkerConfig,
|
||||
RedisConfig,
|
||||
ExperimentalConfig,
|
||||
]
|
||||
|
@ -349,4 +349,4 @@ class RegistrationConfig(Config):
|
||||
|
||||
def read_arguments(self, args):
|
||||
if args.enable_registration is not None:
|
||||
self.enable_registration = bool(strtobool(str(args.enable_registration)))
|
||||
self.enable_registration = strtobool(str(args.enable_registration))
|
||||
|
@ -164,7 +164,13 @@ class SAML2Config(Config):
|
||||
config_path = saml2_config.get("config_path", None)
|
||||
if config_path is not None:
|
||||
mod = load_python_module(config_path)
|
||||
_dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict)
|
||||
config = getattr(mod, "CONFIG", None)
|
||||
if config is None:
|
||||
raise ConfigError(
|
||||
"Config path specified by saml2_config.config_path does not "
|
||||
"have a CONFIG property."
|
||||
)
|
||||
_dict_merge(merge_dict=config, into_dict=saml2_config_dict)
|
||||
|
||||
import saml2.config
|
||||
|
||||
|
@ -16,11 +16,8 @@ import logging
|
||||
import os
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import List, Optional, Pattern
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from OpenSSL import SSL, crypto
|
||||
from twisted.internet._sslverify import Certificate, trustRootFromCertificates
|
||||
|
||||
@ -83,13 +80,6 @@ class TlsConfig(Config):
|
||||
"configured."
|
||||
)
|
||||
|
||||
self._original_tls_fingerprints = config.get("tls_fingerprints", [])
|
||||
|
||||
if self._original_tls_fingerprints is None:
|
||||
self._original_tls_fingerprints = []
|
||||
|
||||
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
||||
|
||||
# Whether to verify certificates on outbound federation traffic
|
||||
self.federation_verify_certificates = config.get(
|
||||
"federation_verify_certificates", True
|
||||
@ -248,19 +238,6 @@ class TlsConfig(Config):
|
||||
e,
|
||||
)
|
||||
|
||||
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
||||
|
||||
if self.tls_certificate:
|
||||
# Check that our own certificate is included in the list of fingerprints
|
||||
# and include it if it is not.
|
||||
x509_certificate_bytes = crypto.dump_certificate(
|
||||
crypto.FILETYPE_ASN1, self.tls_certificate
|
||||
)
|
||||
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
|
||||
sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints}
|
||||
if sha256_fingerprint not in sha256_fingerprints:
|
||||
self.tls_fingerprints.append({"sha256": sha256_fingerprint})
|
||||
|
||||
def generate_config_section(
|
||||
self,
|
||||
config_dir_path,
|
||||
@ -443,33 +420,6 @@ class TlsConfig(Config):
|
||||
# If unspecified, we will use CONFDIR/client.key.
|
||||
#
|
||||
account_key_file: %(default_acme_account_file)s
|
||||
|
||||
# List of allowed TLS fingerprints for this server to publish along
|
||||
# with the signing keys for this server. Other matrix servers that
|
||||
# make HTTPS requests to this server will check that the TLS
|
||||
# certificates returned by this server match one of the fingerprints.
|
||||
#
|
||||
# Synapse automatically adds the fingerprint of its own certificate
|
||||
# to the list. So if federation traffic is handled directly by synapse
|
||||
# then no modification to the list is required.
|
||||
#
|
||||
# If synapse is run behind a load balancer that handles the TLS then it
|
||||
# will be necessary to add the fingerprints of the certificates used by
|
||||
# the loadbalancers to this list if they are different to the one
|
||||
# synapse is using.
|
||||
#
|
||||
# Homeservers are permitted to cache the list of TLS fingerprints
|
||||
# returned in the key responses up to the "valid_until_ts" returned in
|
||||
# key. It may be necessary to publish the fingerprints of a new
|
||||
# certificate and wait until the "valid_until_ts" of the previous key
|
||||
# responses have passed before deploying it.
|
||||
#
|
||||
# You can calculate a fingerprint from a given TLS listener via:
|
||||
# openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
|
||||
# openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
|
||||
# or by checking matrix.org/federationtester/api/report?server_name=$host
|
||||
#
|
||||
#tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
||||
"""
|
||||
# Lowercase the string representation of boolean values
|
||||
% {
|
||||
|
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Set
|
||||
|
||||
from synapse.python_dependencies import DependencyException, check_requirements
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
@ -32,6 +34,8 @@ class TracerConfig(Config):
|
||||
{"sampler": {"type": "const", "param": 1}, "logging": False},
|
||||
)
|
||||
|
||||
self.force_tracing_for_users: Set[str] = set()
|
||||
|
||||
if not self.opentracer_enabled:
|
||||
return
|
||||
|
||||
@ -48,6 +52,19 @@ class TracerConfig(Config):
|
||||
if not isinstance(self.opentracer_whitelist, list):
|
||||
raise ConfigError("Tracer homeserver_whitelist config is malformed")
|
||||
|
||||
force_tracing_for_users = opentracing_config.get("force_tracing_for_users", [])
|
||||
if not isinstance(force_tracing_for_users, list):
|
||||
raise ConfigError(
|
||||
"Expected a list", ("opentracing", "force_tracing_for_users")
|
||||
)
|
||||
for i, u in enumerate(force_tracing_for_users):
|
||||
if not isinstance(u, str):
|
||||
raise ConfigError(
|
||||
"Expected a string",
|
||||
("opentracing", "force_tracing_for_users", f"index {i}"),
|
||||
)
|
||||
self.force_tracing_for_users.add(u)
|
||||
|
||||
def generate_config_section(cls, **kwargs):
|
||||
return """\
|
||||
## Opentracing ##
|
||||
@ -64,7 +81,8 @@ class TracerConfig(Config):
|
||||
#enabled: true
|
||||
|
||||
# The list of homeservers we wish to send and receive span contexts and span baggage.
|
||||
# See docs/opentracing.rst
|
||||
# See docs/opentracing.rst.
|
||||
#
|
||||
# This is a list of regexes which are matched against the server_name of the
|
||||
# homeserver.
|
||||
#
|
||||
@ -73,19 +91,26 @@ class TracerConfig(Config):
|
||||
#homeserver_whitelist:
|
||||
# - ".*"
|
||||
|
||||
# A list of the matrix IDs of users whose requests will always be traced,
|
||||
# even if the tracing system would otherwise drop the traces due to
|
||||
# probabilistic sampling.
|
||||
#
|
||||
# By default, the list is empty.
|
||||
#
|
||||
#force_tracing_for_users:
|
||||
# - "@user1:server_name"
|
||||
# - "@user2:server_name"
|
||||
|
||||
# Jaeger can be configured to sample traces at different rates.
|
||||
# All configuration options provided by Jaeger can be set here.
|
||||
# Jaeger's configuration mostly related to trace sampling which
|
||||
# Jaeger's configuration is mostly related to trace sampling which
|
||||
# is documented here:
|
||||
# https://www.jaegertracing.io/docs/1.13/sampling/.
|
||||
# https://www.jaegertracing.io/docs/latest/sampling/.
|
||||
#
|
||||
#jaeger_config:
|
||||
# sampler:
|
||||
# type: const
|
||||
# param: 1
|
||||
|
||||
# Logging whether spans were started and reported
|
||||
#
|
||||
# logging:
|
||||
# false
|
||||
"""
|
||||
|
@ -17,7 +17,7 @@ import abc
|
||||
import logging
|
||||
import urllib
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
from signedjson.key import (
|
||||
@ -42,6 +42,8 @@ from synapse.api.errors import (
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.config.key import TrustedKeyServer
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import prune_event_dict
|
||||
from synapse.logging.context import (
|
||||
PreserveLoggingContext,
|
||||
make_deferred_yieldable,
|
||||
@ -69,7 +71,11 @@ class VerifyJsonRequest:
|
||||
Attributes:
|
||||
server_name: The name of the server to verify against.
|
||||
|
||||
json_object: The JSON object to verify.
|
||||
get_json_object: A callback to fetch the JSON object to verify.
|
||||
A callback is used to allow deferring the creation of the JSON
|
||||
object to verify until needed, e.g. for events we can defer
|
||||
creating the redacted copy. This reduces the memory usage when
|
||||
there are large numbers of in flight requests.
|
||||
|
||||
minimum_valid_until_ts: time at which we require the signing key to
|
||||
be valid. (0 implies we don't care)
|
||||
@ -88,14 +94,50 @@ class VerifyJsonRequest:
|
||||
"""
|
||||
|
||||
server_name = attr.ib(type=str)
|
||||
json_object = attr.ib(type=JsonDict)
|
||||
get_json_object = attr.ib(type=Callable[[], JsonDict])
|
||||
minimum_valid_until_ts = attr.ib(type=int)
|
||||
request_name = attr.ib(type=str)
|
||||
key_ids = attr.ib(init=False, type=List[str])
|
||||
key_ids = attr.ib(type=List[str])
|
||||
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.key_ids = signature_ids(self.json_object, self.server_name)
|
||||
@staticmethod
|
||||
def from_json_object(
|
||||
server_name: str,
|
||||
json_object: JsonDict,
|
||||
minimum_valid_until_ms: int,
|
||||
request_name: str,
|
||||
):
|
||||
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
|
||||
object for the given server.
|
||||
"""
|
||||
key_ids = signature_ids(json_object, server_name)
|
||||
return VerifyJsonRequest(
|
||||
server_name,
|
||||
lambda: json_object,
|
||||
minimum_valid_until_ms,
|
||||
request_name=request_name,
|
||||
key_ids=key_ids,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_event(
|
||||
server_name: str,
|
||||
event: EventBase,
|
||||
minimum_valid_until_ms: int,
|
||||
):
|
||||
"""Create a VerifyJsonRequest to verify all signatures on an event
|
||||
object for the given server.
|
||||
"""
|
||||
key_ids = list(event.signatures.get(server_name, []))
|
||||
return VerifyJsonRequest(
|
||||
server_name,
|
||||
# We defer creating the redacted json object, as it uses a lot more
|
||||
# memory than the Event object itself.
|
||||
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
|
||||
minimum_valid_until_ms,
|
||||
request_name=event.event_id,
|
||||
key_ids=key_ids,
|
||||
)
|
||||
|
||||
|
||||
class KeyLookupError(ValueError):
|
||||
@ -147,8 +189,13 @@ class Keyring:
|
||||
Deferred[None]: completes if the the object was correctly signed, otherwise
|
||||
errbacks with an error
|
||||
"""
|
||||
req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
|
||||
requests = (req,)
|
||||
request = VerifyJsonRequest.from_json_object(
|
||||
server_name,
|
||||
json_object,
|
||||
validity_time,
|
||||
request_name,
|
||||
)
|
||||
requests = (request,)
|
||||
return make_deferred_yieldable(self._verify_objects(requests)[0])
|
||||
|
||||
def verify_json_objects_for_server(
|
||||
@ -175,10 +222,41 @@ class Keyring:
|
||||
logcontext.
|
||||
"""
|
||||
return self._verify_objects(
|
||||
VerifyJsonRequest(server_name, json_object, validity_time, request_name)
|
||||
VerifyJsonRequest.from_json_object(
|
||||
server_name, json_object, validity_time, request_name
|
||||
)
|
||||
for server_name, json_object, validity_time, request_name in server_and_json
|
||||
)
|
||||
|
||||
def verify_events_for_server(
|
||||
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
|
||||
) -> List[defer.Deferred]:
|
||||
"""Bulk verification of signatures on events.
|
||||
|
||||
Args:
|
||||
server_and_events:
|
||||
Iterable of `(server_name, event, validity_time)` tuples.
|
||||
|
||||
`server_name` is which server we are verifying the signature for
|
||||
on the event.
|
||||
|
||||
`event` is the event that we'll verify the signatures of for
|
||||
the given `server_name`.
|
||||
|
||||
`validity_time` is a timestamp at which the signing key must be
|
||||
valid.
|
||||
|
||||
Returns:
|
||||
List<Deferred[None]>: for each input triplet, a deferred indicating success
|
||||
or failure to verify each event's signature for the given
|
||||
server_name. The deferreds run their callbacks in the sentinel
|
||||
logcontext.
|
||||
"""
|
||||
return self._verify_objects(
|
||||
VerifyJsonRequest.from_event(server_name, event, validity_time)
|
||||
for server_name, event, validity_time in server_and_events
|
||||
)
|
||||
|
||||
def _verify_objects(
|
||||
self, verify_requests: Iterable[VerifyJsonRequest]
|
||||
) -> List[defer.Deferred]:
|
||||
@ -892,7 +970,7 @@ async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
|
||||
with PreserveLoggingContext():
|
||||
_, key_id, verify_key = await verify_request.key_ready
|
||||
|
||||
json_object = verify_request.json_object
|
||||
json_object = verify_request.get_json_object()
|
||||
|
||||
try:
|
||||
verify_signed_json(json_object, server_name, verify_key)
|
||||
|
@ -137,11 +137,7 @@ class FederationBase:
|
||||
return deferreds
|
||||
|
||||
|
||||
class PduToCheckSig(
|
||||
namedtuple(
|
||||
"PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
|
||||
)
|
||||
):
|
||||
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
|
||||
pass
|
||||
|
||||
|
||||
@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
|
||||
pdus_to_check = [
|
||||
PduToCheckSig(
|
||||
pdu=p,
|
||||
redacted_pdu_json=prune_event(p).get_pdu_json(),
|
||||
sender_domain=get_domain_from_id(p.sender),
|
||||
deferreds=[],
|
||||
)
|
||||
@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
|
||||
# (except if its a 3pid invite, in which case it may be sent by any server)
|
||||
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
|
||||
|
||||
more_deferreds = keyring.verify_json_objects_for_server(
|
||||
more_deferreds = keyring.verify_events_for_server(
|
||||
[
|
||||
(
|
||||
p.sender_domain,
|
||||
p.redacted_pdu_json,
|
||||
p.pdu,
|
||||
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
||||
p.pdu.event_id,
|
||||
)
|
||||
for p in pdus_to_check_sender
|
||||
]
|
||||
@ -230,13 +224,12 @@ def _check_sigs_on_pdus(
|
||||
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
|
||||
]
|
||||
|
||||
more_deferreds = keyring.verify_json_objects_for_server(
|
||||
more_deferreds = keyring.verify_events_for_server(
|
||||
[
|
||||
(
|
||||
get_domain_from_id(p.pdu.event_id),
|
||||
p.redacted_pdu_json,
|
||||
p.pdu,
|
||||
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
||||
p.pdu.event_id,
|
||||
)
|
||||
for p in pdus_to_check_event_id
|
||||
]
|
||||
|
@ -55,6 +55,7 @@ from synapse.api.room_versions import (
|
||||
)
|
||||
from synapse.events import EventBase, builder
|
||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||
from synapse.federation.transport.client import SendJoinResponse
|
||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
@ -665,19 +666,10 @@ class FederationClient(FederationBase):
|
||||
"""
|
||||
|
||||
async def send_request(destination) -> Dict[str, Any]:
|
||||
content = await self._do_send_join(destination, pdu)
|
||||
response = await self._do_send_join(room_version, destination, pdu)
|
||||
|
||||
logger.debug("Got content: %s", content)
|
||||
|
||||
state = [
|
||||
event_from_pdu_json(p, room_version, outlier=True)
|
||||
for p in content.get("state", [])
|
||||
]
|
||||
|
||||
auth_chain = [
|
||||
event_from_pdu_json(p, room_version, outlier=True)
|
||||
for p in content.get("auth_chain", [])
|
||||
]
|
||||
state = response.state
|
||||
auth_chain = response.auth_events
|
||||
|
||||
pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
|
||||
|
||||
@ -752,11 +744,14 @@ class FederationClient(FederationBase):
|
||||
|
||||
return await self._try_destination_list("send_join", destinations, send_request)
|
||||
|
||||
async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
|
||||
async def _do_send_join(
|
||||
self, room_version: RoomVersion, destination: str, pdu: EventBase
|
||||
) -> SendJoinResponse:
|
||||
time_now = self._clock.time_msec()
|
||||
|
||||
try:
|
||||
return await self.transport_layer.send_join_v2(
|
||||
room_version=room_version,
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
@ -771,17 +766,14 @@ class FederationClient(FederationBase):
|
||||
|
||||
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
|
||||
|
||||
resp = await self.transport_layer.send_join_v1(
|
||||
return await self.transport_layer.send_join_v1(
|
||||
room_version=room_version,
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
|
||||
# We expect the v1 API to respond with [200, content], so we only return the
|
||||
# content.
|
||||
return resp[1]
|
||||
|
||||
async def send_invite(
|
||||
self,
|
||||
destination: str,
|
||||
|
@ -17,13 +17,19 @@ import logging
|
||||
import urllib
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import attr
|
||||
import ijson
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.api.urls import (
|
||||
FEDERATION_UNSTABLE_PREFIX,
|
||||
FEDERATION_V1_PREFIX,
|
||||
FEDERATION_V2_PREFIX,
|
||||
)
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.http.matrixfederationclient import ByteParser
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.types import JsonDict
|
||||
|
||||
@ -240,21 +246,36 @@ class TransportLayerClient:
|
||||
return content
|
||||
|
||||
@log_function
|
||||
async def send_join_v1(self, destination, room_id, event_id, content):
|
||||
async def send_join_v1(
|
||||
self,
|
||||
room_version,
|
||||
destination,
|
||||
room_id,
|
||||
event_id,
|
||||
content,
|
||||
) -> "SendJoinResponse":
|
||||
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
|
||||
|
||||
response = await self.client.put_json(
|
||||
destination=destination, path=path, data=content
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
parser=SendJoinParser(room_version, v1_api=True),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@log_function
|
||||
async def send_join_v2(self, destination, room_id, event_id, content):
|
||||
async def send_join_v2(
|
||||
self, room_version, destination, room_id, event_id, content
|
||||
) -> "SendJoinResponse":
|
||||
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
|
||||
|
||||
response = await self.client.put_json(
|
||||
destination=destination, path=path, data=content
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
parser=SendJoinParser(room_version, v1_api=False),
|
||||
)
|
||||
|
||||
return response
|
||||
@ -1053,3 +1074,59 @@ def _create_v2_path(path, *args):
|
||||
str
|
||||
"""
|
||||
return _create_path(FEDERATION_V2_PREFIX, path, *args)
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class SendJoinResponse:
|
||||
"""The parsed response of a `/send_join` request."""
|
||||
|
||||
auth_events: List[EventBase]
|
||||
state: List[EventBase]
|
||||
|
||||
|
||||
@ijson.coroutine
|
||||
def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
|
||||
"""Helper function for use with `ijson.items_coro` to parse an array of
|
||||
events and add them to the given list.
|
||||
"""
|
||||
|
||||
while True:
|
||||
obj = yield
|
||||
event = make_event_from_dict(obj, room_version)
|
||||
events.append(event)
|
||||
|
||||
|
||||
class SendJoinParser(ByteParser[SendJoinResponse]):
|
||||
"""A parser for the response to `/send_join` requests.
|
||||
|
||||
Args:
|
||||
room_version: The version of the room.
|
||||
v1_api: Whether the response is in the v1 format.
|
||||
"""
|
||||
|
||||
CONTENT_TYPE = "application/json"
|
||||
|
||||
def __init__(self, room_version: RoomVersion, v1_api: bool):
|
||||
self._response = SendJoinResponse([], [])
|
||||
|
||||
# The V1 API has the shape of `[200, {...}]`, which we handle by
|
||||
# prefixing with `item.*`.
|
||||
prefix = "item." if v1_api else ""
|
||||
|
||||
self._coro_state = ijson.items_coro(
|
||||
_event_list_parser(room_version, self._response.state),
|
||||
prefix + "state.item",
|
||||
)
|
||||
self._coro_auth = ijson.items_coro(
|
||||
_event_list_parser(room_version, self._response.auth_events),
|
||||
prefix + "auth_chain.item",
|
||||
)
|
||||
|
||||
def write(self, data: bytes) -> int:
|
||||
self._coro_state.send(data)
|
||||
self._coro_auth.send(data)
|
||||
|
||||
return len(data)
|
||||
|
||||
def finish(self) -> SendJoinResponse:
|
||||
return self._response
|
||||
|
@ -160,7 +160,7 @@ class Authenticator:
|
||||
# If we get a valid signed request from the other side, its probably
|
||||
# alive
|
||||
retry_timings = await self.store.get_destination_retry_timings(origin)
|
||||
if retry_timings and retry_timings["retry_last_ts"]:
|
||||
if retry_timings and retry_timings.retry_last_ts:
|
||||
run_in_background(self._reset_retry_timings, origin)
|
||||
|
||||
return origin
|
||||
@ -1428,7 +1428,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
|
||||
)
|
||||
|
||||
return 200, await self.handler.federation_space_summary(
|
||||
room_id, suggested_only, max_rooms_per_space, exclude_rooms
|
||||
origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
|
||||
)
|
||||
|
||||
|
||||
|
@ -15,12 +15,9 @@
|
||||
import email.mime.multipart
|
||||
import email.utils
|
||||
import logging
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import StoreError, SynapseError
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.types import UserID
|
||||
from synapse.util import stringutils
|
||||
@ -36,9 +33,11 @@ class AccountValidityHandler:
|
||||
self.hs = hs
|
||||
self.config = hs.config
|
||||
self.store = self.hs.get_datastore()
|
||||
self.sendmail = self.hs.get_sendmail()
|
||||
self.send_email_handler = self.hs.get_send_email_handler()
|
||||
self.clock = self.hs.get_clock()
|
||||
|
||||
self._app_name = self.hs.config.email_app_name
|
||||
|
||||
self._account_validity_enabled = (
|
||||
hs.config.account_validity.account_validity_enabled
|
||||
)
|
||||
@ -63,23 +62,10 @@ class AccountValidityHandler:
|
||||
self._template_text = (
|
||||
hs.config.account_validity.account_validity_template_text
|
||||
)
|
||||
account_validity_renew_email_subject = (
|
||||
self._renew_email_subject = (
|
||||
hs.config.account_validity.account_validity_renew_email_subject
|
||||
)
|
||||
|
||||
try:
|
||||
app_name = hs.config.email_app_name
|
||||
|
||||
self._subject = account_validity_renew_email_subject % {"app": app_name}
|
||||
|
||||
self._from_string = hs.config.email_notif_from % {"app": app_name}
|
||||
except Exception:
|
||||
# If substitution failed, fall back to the bare strings.
|
||||
self._subject = account_validity_renew_email_subject
|
||||
self._from_string = hs.config.email_notif_from
|
||||
|
||||
self._raw_from = email.utils.parseaddr(self._from_string)[1]
|
||||
|
||||
# Check the renewal emails to send and send them every 30min.
|
||||
if hs.config.run_background_tasks:
|
||||
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
|
||||
@ -159,38 +145,17 @@ class AccountValidityHandler:
|
||||
}
|
||||
|
||||
html_text = self._template_html.render(**template_vars)
|
||||
html_part = MIMEText(html_text, "html", "utf8")
|
||||
|
||||
plain_text = self._template_text.render(**template_vars)
|
||||
text_part = MIMEText(plain_text, "plain", "utf8")
|
||||
|
||||
for address in addresses:
|
||||
raw_to = email.utils.parseaddr(address)[1]
|
||||
|
||||
multipart_msg = MIMEMultipart("alternative")
|
||||
multipart_msg["Subject"] = self._subject
|
||||
multipart_msg["From"] = self._from_string
|
||||
multipart_msg["To"] = address
|
||||
multipart_msg["Date"] = email.utils.formatdate()
|
||||
multipart_msg["Message-ID"] = email.utils.make_msgid()
|
||||
multipart_msg.attach(text_part)
|
||||
multipart_msg.attach(html_part)
|
||||
|
||||
logger.info("Sending renewal email to %s", address)
|
||||
|
||||
await make_deferred_yieldable(
|
||||
self.sendmail(
|
||||
self.hs.config.email_smtp_host,
|
||||
self._raw_from,
|
||||
raw_to,
|
||||
multipart_msg.as_string().encode("utf8"),
|
||||
reactor=self.hs.get_reactor(),
|
||||
port=self.hs.config.email_smtp_port,
|
||||
requireAuthentication=self.hs.config.email_smtp_user is not None,
|
||||
username=self.hs.config.email_smtp_user,
|
||||
password=self.hs.config.email_smtp_pass,
|
||||
requireTransportSecurity=self.hs.config.require_transport_security,
|
||||
)
|
||||
await self.send_email_handler.send_email(
|
||||
email_address=raw_to,
|
||||
subject=self._renew_email_subject,
|
||||
app_name=self._app_name,
|
||||
html=html_text,
|
||||
text=plain_text,
|
||||
)
|
||||
|
||||
await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
|
||||
|
@ -11,10 +11,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Collection, Optional
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules
|
||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -29,46 +31,104 @@ class EventAuthHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._store = hs.get_datastore()
|
||||
|
||||
async def can_join_without_invite(
|
||||
self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
|
||||
) -> bool:
|
||||
async def check_restricted_join_rules(
|
||||
self,
|
||||
state_ids: StateMap[str],
|
||||
room_version: RoomVersion,
|
||||
user_id: str,
|
||||
prev_member_event: Optional[EventBase],
|
||||
) -> None:
|
||||
"""
|
||||
Check whether a user can join a room without an invite.
|
||||
Check whether a user can join a room without an invite due to restricted join rules.
|
||||
|
||||
When joining a room with restricted joined rules (as defined in MSC3083),
|
||||
the membership of spaces must be checked during join.
|
||||
the membership of spaces must be checked during a room join.
|
||||
|
||||
Args:
|
||||
state_ids: The state of the room as it currently is.
|
||||
room_version: The room version of the room being joined.
|
||||
user_id: The user joining the room.
|
||||
prev_member_event: The current membership event for this user.
|
||||
|
||||
Raises:
|
||||
AuthError if the user cannot join the room.
|
||||
"""
|
||||
# If the member is invited or currently joined, then nothing to do.
|
||||
if prev_member_event and (
|
||||
prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
|
||||
):
|
||||
return
|
||||
|
||||
# This is not a room with a restricted join rule, so we don't need to do the
|
||||
# restricted room specific checks.
|
||||
#
|
||||
# Note: We'll be applying the standard join rule checks later, which will
|
||||
# catch the cases of e.g. trying to join private rooms without an invite.
|
||||
if not await self.has_restricted_join_rules(state_ids, room_version):
|
||||
return
|
||||
|
||||
# Get the spaces which allow access to this room and check if the user is
|
||||
# in any of them.
|
||||
allowed_spaces = await self.get_spaces_that_allow_join(state_ids)
|
||||
if not await self.is_user_in_rooms(allowed_spaces, user_id):
|
||||
raise AuthError(
|
||||
403,
|
||||
"You do not belong to any of the required spaces to join this room.",
|
||||
)
|
||||
|
||||
async def has_restricted_join_rules(
|
||||
self, state_ids: StateMap[str], room_version: RoomVersion
|
||||
) -> bool:
|
||||
"""
|
||||
Return if the room has the proper join rules set for access via spaces.
|
||||
|
||||
Args:
|
||||
state_ids: The state of the room as it currently is.
|
||||
room_version: The room version of the room to query.
|
||||
|
||||
Returns:
|
||||
True if the user can join the room, false otherwise.
|
||||
True if the proper room version and join rules are set for restricted access.
|
||||
"""
|
||||
# This only applies to room versions which support the new join rule.
|
||||
if not room_version.msc3083_join_rules:
|
||||
return True
|
||||
return False
|
||||
|
||||
# If there's no join rule, then it defaults to invite (so this doesn't apply).
|
||||
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
|
||||
if not join_rules_event_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
# If the join rule is not restricted, this doesn't apply.
|
||||
join_rules_event = await self._store.get_event(join_rules_event_id)
|
||||
return join_rules_event.content.get("join_rule") == JoinRules.MSC3083_RESTRICTED
|
||||
|
||||
async def get_spaces_that_allow_join(
|
||||
self, state_ids: StateMap[str]
|
||||
) -> Collection[str]:
|
||||
"""
|
||||
Generate a list of spaces which allow access to a room.
|
||||
|
||||
Args:
|
||||
state_ids: The state of the room as it currently is.
|
||||
|
||||
Returns:
|
||||
A collection of spaces which provide membership to the room.
|
||||
"""
|
||||
# If there's no join rule, then it defaults to invite (so this doesn't apply).
|
||||
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
|
||||
if not join_rules_event_id:
|
||||
return ()
|
||||
|
||||
# If the join rule is not restricted, this doesn't apply.
|
||||
join_rules_event = await self._store.get_event(join_rules_event_id)
|
||||
if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
|
||||
return True
|
||||
|
||||
# If allowed is of the wrong form, then only allow invited users.
|
||||
allowed_spaces = join_rules_event.content.get("allow", [])
|
||||
if not isinstance(allowed_spaces, list):
|
||||
return False
|
||||
|
||||
# Get the list of joined rooms and see if there's an overlap.
|
||||
joined_rooms = await self._store.get_rooms_for_user(user_id)
|
||||
return ()
|
||||
|
||||
# Pull out the other room IDs, invalid data gets filtered.
|
||||
result = []
|
||||
for space in allowed_spaces:
|
||||
if not isinstance(space, dict):
|
||||
continue
|
||||
@ -77,10 +137,31 @@ class EventAuthHandler:
|
||||
if not isinstance(space_id, str):
|
||||
continue
|
||||
|
||||
# The user was joined to one of the spaces specified, they can join
|
||||
# this room!
|
||||
if space_id in joined_rooms:
|
||||
result.append(space_id)
|
||||
|
||||
return result
|
||||
|
||||
async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool:
|
||||
"""
|
||||
Check whether a user is a member of any of the provided rooms.
|
||||
|
||||
Args:
|
||||
room_ids: The rooms to check for membership.
|
||||
user_id: The user to check.
|
||||
|
||||
Returns:
|
||||
True if the user is in any of the rooms, false otherwise.
|
||||
"""
|
||||
if not room_ids:
|
||||
return False
|
||||
|
||||
# Get the list of joined rooms and see if there's an overlap.
|
||||
joined_rooms = await self._store.get_rooms_for_user(user_id)
|
||||
|
||||
# Check each room and see if the user is in it.
|
||||
for room_id in room_ids:
|
||||
if room_id in joined_rooms:
|
||||
return True
|
||||
|
||||
# The user was not in any of the required spaces.
|
||||
# The user was not in any of the rooms.
|
||||
return False
|
||||
|
@ -1668,28 +1668,17 @@ class FederationHandler(BaseHandler):
|
||||
# Check if the user is already in the room or invited to the room.
|
||||
user_id = event.state_key
|
||||
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
||||
newly_joined = True
|
||||
user_is_invited = False
|
||||
prev_member_event = None
|
||||
if prev_member_event_id:
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
user_is_invited = prev_member_event.membership == Membership.INVITE
|
||||
|
||||
# If the member is not already in the room, and not invited, check if
|
||||
# they should be allowed access via membership in a space.
|
||||
if (
|
||||
newly_joined
|
||||
and not user_is_invited
|
||||
and not await self._event_auth_handler.can_join_without_invite(
|
||||
prev_state_ids,
|
||||
event.room_version,
|
||||
user_id,
|
||||
)
|
||||
):
|
||||
raise AuthError(
|
||||
403,
|
||||
"You do not belong to any of the required spaces to join this room.",
|
||||
)
|
||||
# Check if the member should be allowed access via membership in a space.
|
||||
await self._event_auth_handler.check_restricted_join_rules(
|
||||
prev_state_ids,
|
||||
event.room_version,
|
||||
user_id,
|
||||
prev_member_event,
|
||||
)
|
||||
|
||||
# Persist the event.
|
||||
await self._auth_and_persist_event(origin, event, context)
|
||||
|
@ -222,9 +222,21 @@ class BasePresenceHandler(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def set_state(
|
||||
self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
|
||||
self,
|
||||
target_user: UserID,
|
||||
state: JsonDict,
|
||||
ignore_status_msg: bool = False,
|
||||
force_notify: bool = False,
|
||||
) -> None:
|
||||
"""Set the presence state of the user. """
|
||||
"""Set the presence state of the user.
|
||||
|
||||
Args:
|
||||
target_user: The ID of the user to set the presence state of.
|
||||
state: The presence state as a JSON dictionary.
|
||||
ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
|
||||
If False, the user's current status will be updated.
|
||||
force_notify: Whether to force notification of the update to clients.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def bump_presence_active_time(self, user: UserID):
|
||||
@ -296,6 +308,51 @@ class BasePresenceHandler(abc.ABC):
|
||||
for destinations, states in hosts_and_states:
|
||||
self._federation.send_presence_to_destinations(states, destinations)
|
||||
|
||||
async def send_full_presence_to_users(self, user_ids: Collection[str]):
|
||||
"""
|
||||
Adds to the list of users who should receive a full snapshot of presence
|
||||
upon their next sync. Note that this only works for local users.
|
||||
|
||||
Then, grabs the current presence state for a given set of users and adds it
|
||||
to the top of the presence stream.
|
||||
|
||||
Args:
|
||||
user_ids: The IDs of the local users to send full presence to.
|
||||
"""
|
||||
# Retrieve one of the users from the given set
|
||||
if not user_ids:
|
||||
raise Exception(
|
||||
"send_full_presence_to_users must be called with at least one user"
|
||||
)
|
||||
user_id = next(iter(user_ids))
|
||||
|
||||
# Mark all users as receiving full presence on their next sync
|
||||
await self.store.add_users_to_send_full_presence_to(user_ids)
|
||||
|
||||
# Add a new entry to the presence stream. Since we use stream tokens to determine whether a
|
||||
# local user should receive a full snapshot of presence when they sync, we need to bump the
|
||||
# presence stream so that subsequent syncs with no presence activity in between won't result
|
||||
# in the client receiving multiple full snapshots of presence.
|
||||
#
|
||||
# If we bump the stream ID, then the user will get a higher stream token next sync, and thus
|
||||
# correctly won't receive a second snapshot.
|
||||
|
||||
# Get the current presence state for one of the users (defaults to offline if not found)
|
||||
current_presence_state = await self.get_state(UserID.from_string(user_id))
|
||||
|
||||
# Convert the UserPresenceState object into a serializable dict
|
||||
state = {
|
||||
"presence": current_presence_state.state,
|
||||
"status_message": current_presence_state.status_msg,
|
||||
}
|
||||
|
||||
# Copy the presence state to the tip of the presence stream.
|
||||
|
||||
# We set force_notify=True here so that this presence update is guaranteed to
|
||||
# increment the presence stream ID (which resending the current user's presence
|
||||
# otherwise would not do).
|
||||
await self.set_state(UserID.from_string(user_id), state, force_notify=True)
|
||||
|
||||
|
||||
class _NullContextManager(ContextManager[None]):
|
||||
"""A context manager which does nothing."""
|
||||
@ -480,8 +537,17 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
||||
target_user: UserID,
|
||||
state: JsonDict,
|
||||
ignore_status_msg: bool = False,
|
||||
force_notify: bool = False,
|
||||
) -> None:
|
||||
"""Set the presence state of the user."""
|
||||
"""Set the presence state of the user.
|
||||
|
||||
Args:
|
||||
target_user: The ID of the user to set the presence state of.
|
||||
state: The presence state as a JSON dictionary.
|
||||
ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
|
||||
If False, the user's current status will be updated.
|
||||
force_notify: Whether to force notification of the update to clients.
|
||||
"""
|
||||
presence = state["presence"]
|
||||
|
||||
valid_presence = (
|
||||
@ -508,6 +574,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
||||
user_id=user_id,
|
||||
state=state,
|
||||
ignore_status_msg=ignore_status_msg,
|
||||
force_notify=force_notify,
|
||||
)
|
||||
|
||||
async def bump_presence_active_time(self, user: UserID) -> None:
|
||||
@ -677,13 +744,19 @@ class PresenceHandler(BasePresenceHandler):
|
||||
[self.user_to_current_state[user_id] for user_id in unpersisted]
|
||||
)
|
||||
|
||||
async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None:
|
||||
async def _update_states(
|
||||
self, new_states: Iterable[UserPresenceState], force_notify: bool = False
|
||||
) -> None:
|
||||
"""Updates presence of users. Sets the appropriate timeouts. Pokes
|
||||
the notifier and federation if and only if the changed presence state
|
||||
should be sent to clients/servers.
|
||||
|
||||
Args:
|
||||
new_states: The new user presence state updates to process.
|
||||
force_notify: Whether to force notifying clients of this presence state update,
|
||||
even if it doesn't change the state of a user's presence (e.g online -> online).
|
||||
This is currently used to bump the max presence stream ID without changing any
|
||||
user's presence (see PresenceHandler.add_users_to_send_full_presence_to).
|
||||
"""
|
||||
now = self.clock.time_msec()
|
||||
|
||||
@ -720,6 +793,9 @@ class PresenceHandler(BasePresenceHandler):
|
||||
now=now,
|
||||
)
|
||||
|
||||
if force_notify:
|
||||
should_notify = True
|
||||
|
||||
self.user_to_current_state[user_id] = new_state
|
||||
|
||||
if should_notify:
|
||||
@ -1058,9 +1134,21 @@ class PresenceHandler(BasePresenceHandler):
|
||||
await self._update_states(updates)
|
||||
|
||||
async def set_state(
|
||||
self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
|
||||
self,
|
||||
target_user: UserID,
|
||||
state: JsonDict,
|
||||
ignore_status_msg: bool = False,
|
||||
force_notify: bool = False,
|
||||
) -> None:
|
||||
"""Set the presence state of the user."""
|
||||
"""Set the presence state of the user.
|
||||
|
||||
Args:
|
||||
target_user: The ID of the user to set the presence state of.
|
||||
state: The presence state as a JSON dictionary.
|
||||
ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
|
||||
If False, the user's current status will be updated.
|
||||
force_notify: Whether to force notification of the update to clients.
|
||||
"""
|
||||
status_msg = state.get("status_msg", None)
|
||||
presence = state["presence"]
|
||||
|
||||
@ -1091,7 +1179,9 @@ class PresenceHandler(BasePresenceHandler):
|
||||
):
|
||||
new_fields["last_active_ts"] = self.clock.time_msec()
|
||||
|
||||
await self._update_states([prev_state.copy_and_replace(**new_fields)])
|
||||
await self._update_states(
|
||||
[prev_state.copy_and_replace(**new_fields)], force_notify=force_notify
|
||||
)
|
||||
|
||||
async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
|
||||
"""Returns whether a user can see another user's presence."""
|
||||
@ -1389,11 +1479,10 @@ class PresenceEventSource:
|
||||
#
|
||||
# Presence -> Notifier -> PresenceEventSource -> Presence
|
||||
#
|
||||
# Same with get_module_api, get_presence_router
|
||||
# Same with get_presence_router:
|
||||
#
|
||||
# AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler
|
||||
self.get_presence_handler = hs.get_presence_handler
|
||||
self.get_module_api = hs.get_module_api
|
||||
self.get_presence_router = hs.get_presence_router
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
@ -1424,16 +1513,21 @@ class PresenceEventSource:
|
||||
stream_change_cache = self.store.presence_stream_cache
|
||||
|
||||
with Measure(self.clock, "presence.get_new_events"):
|
||||
if user_id in self.get_module_api()._send_full_presence_to_local_users:
|
||||
# This user has been specified by a module to receive all current, online
|
||||
# user presence. Removing from_key and setting include_offline to false
|
||||
# will do effectively this.
|
||||
from_key = None
|
||||
include_offline = False
|
||||
|
||||
if from_key is not None:
|
||||
from_key = int(from_key)
|
||||
|
||||
# Check if this user should receive all current, online user presence. We only
|
||||
# bother to do this if from_key is set, as otherwise the user will receive all
|
||||
# user presence anyways.
|
||||
if await self.store.should_user_receive_full_presence_with_token(
|
||||
user_id, from_key
|
||||
):
|
||||
# This user has been specified by a module to receive all current, online
|
||||
# user presence. Removing from_key and setting include_offline to false
|
||||
# will do effectively this.
|
||||
from_key = None
|
||||
include_offline = False
|
||||
|
||||
max_token = self.store.get_current_presence_token()
|
||||
if from_key == max_token:
|
||||
# This is necessary as due to the way stream ID generators work
|
||||
@ -1467,12 +1561,6 @@ class PresenceEventSource:
|
||||
user_id, include_offline, from_key
|
||||
)
|
||||
|
||||
# Remove the user from the list of users to receive all presence
|
||||
if user_id in self.get_module_api()._send_full_presence_to_local_users:
|
||||
self.get_module_api()._send_full_presence_to_local_users.remove(
|
||||
user_id
|
||||
)
|
||||
|
||||
return presence_updates, max_token
|
||||
|
||||
# Make mypy happy. users_interested_in should now be a set
|
||||
@ -1522,10 +1610,6 @@ class PresenceEventSource:
|
||||
)
|
||||
presence_updates = list(users_to_state.values())
|
||||
|
||||
# Remove the user from the list of users to receive all presence
|
||||
if user_id in self.get_module_api()._send_full_presence_to_local_users:
|
||||
self.get_module_api()._send_full_presence_to_local_users.remove(user_id)
|
||||
|
||||
if not include_offline:
|
||||
# Filter out offline presence states
|
||||
presence_updates = self._filter_offline_presence_state(presence_updates)
|
||||
|
@ -260,25 +260,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
|
||||
if event.membership == Membership.JOIN:
|
||||
newly_joined = True
|
||||
user_is_invited = False
|
||||
prev_member_event = None
|
||||
if prev_member_event_id:
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
user_is_invited = prev_member_event.membership == Membership.INVITE
|
||||
|
||||
# If the member is not already in the room and is not accepting an invite,
|
||||
# check if they should be allowed access via membership in a space.
|
||||
if (
|
||||
newly_joined
|
||||
and not user_is_invited
|
||||
and not await self.event_auth_handler.can_join_without_invite(
|
||||
prev_state_ids, event.room_version, user_id
|
||||
)
|
||||
):
|
||||
raise AuthError(
|
||||
403,
|
||||
"You do not belong to any of the required spaces to join this room.",
|
||||
)
|
||||
# Check if the member should be allowed access via membership in a space.
|
||||
await self.event_auth_handler.check_restricted_join_rules(
|
||||
prev_state_ids, event.room_version, user_id, prev_member_event
|
||||
)
|
||||
|
||||
# Only rate-limit if the user actually joined the room, otherwise we'll end
|
||||
# up blocking profile updates.
|
||||
|
98
synapse/handlers/send_email.py
Normal file
98
synapse/handlers/send_email.py
Normal file
@ -0,0 +1,98 @@
|
||||
# Copyright 2021 The Matrix.org C.I.C. Foundation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import email.utils
|
||||
import logging
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SendEmailHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
|
||||
self._sendmail = hs.get_sendmail()
|
||||
self._reactor = hs.get_reactor()
|
||||
|
||||
self._from = hs.config.email.email_notif_from
|
||||
self._smtp_host = hs.config.email.email_smtp_host
|
||||
self._smtp_port = hs.config.email.email_smtp_port
|
||||
self._smtp_user = hs.config.email.email_smtp_user
|
||||
self._smtp_pass = hs.config.email.email_smtp_pass
|
||||
self._require_transport_security = hs.config.email.require_transport_security
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
email_address: str,
|
||||
subject: str,
|
||||
app_name: str,
|
||||
html: str,
|
||||
text: str,
|
||||
) -> None:
|
||||
"""Send a multipart email with the given information.
|
||||
|
||||
Args:
|
||||
email_address: The address to send the email to.
|
||||
subject: The email's subject.
|
||||
app_name: The app name to include in the From header.
|
||||
html: The HTML content to include in the email.
|
||||
text: The plain text content to include in the email.
|
||||
"""
|
||||
try:
|
||||
from_string = self._from % {"app": app_name}
|
||||
except (KeyError, TypeError):
|
||||
from_string = self._from
|
||||
|
||||
raw_from = email.utils.parseaddr(from_string)[1]
|
||||
raw_to = email.utils.parseaddr(email_address)[1]
|
||||
|
||||
if raw_to == "":
|
||||
raise RuntimeError("Invalid 'to' address")
|
||||
|
||||
html_part = MIMEText(html, "html", "utf8")
|
||||
text_part = MIMEText(text, "plain", "utf8")
|
||||
|
||||
multipart_msg = MIMEMultipart("alternative")
|
||||
multipart_msg["Subject"] = subject
|
||||
multipart_msg["From"] = from_string
|
||||
multipart_msg["To"] = email_address
|
||||
multipart_msg["Date"] = email.utils.formatdate()
|
||||
multipart_msg["Message-ID"] = email.utils.make_msgid()
|
||||
multipart_msg.attach(text_part)
|
||||
multipart_msg.attach(html_part)
|
||||
|
||||
logger.info("Sending email to %s" % email_address)
|
||||
|
||||
await make_deferred_yieldable(
|
||||
self._sendmail(
|
||||
self._smtp_host,
|
||||
raw_from,
|
||||
raw_to,
|
||||
multipart_msg.as_string().encode("utf8"),
|
||||
reactor=self._reactor,
|
||||
port=self._smtp_port,
|
||||
requireAuthentication=self._smtp_user is not None,
|
||||
username=self._smtp_user,
|
||||
password=self._smtp_pass,
|
||||
requireTransportSecurity=self._require_transport_security,
|
||||
)
|
||||
)
|
@ -16,11 +16,16 @@ import itertools
|
||||
import logging
|
||||
import re
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple, cast
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility
|
||||
from synapse.api.constants import (
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
HistoryVisibility,
|
||||
Membership,
|
||||
)
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import format_event_for_client_v2
|
||||
@ -32,7 +37,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# number of rooms to return. We'll stop once we hit this limit.
|
||||
# TODO: allow clients to reduce this with a request param.
|
||||
MAX_ROOMS = 50
|
||||
|
||||
# max number of events to return per room.
|
||||
@ -46,8 +50,7 @@ class SpaceSummaryHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._clock = hs.get_clock()
|
||||
self._auth = hs.get_auth()
|
||||
self._room_list_handler = hs.get_room_list_handler()
|
||||
self._state_handler = hs.get_state_handler()
|
||||
self._event_auth_handler = hs.get_event_auth_handler()
|
||||
self._store = hs.get_datastore()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self._server_name = hs.hostname
|
||||
@ -112,28 +115,88 @@ class SpaceSummaryHandler:
|
||||
max_children = max_rooms_per_space if processed_rooms else None
|
||||
|
||||
if is_in_room:
|
||||
rooms, events = await self._summarize_local_room(
|
||||
requester, room_id, suggested_only, max_children
|
||||
room, events = await self._summarize_local_room(
|
||||
requester, None, room_id, suggested_only, max_children
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Query of local room %s returned events %s",
|
||||
room_id,
|
||||
["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events],
|
||||
)
|
||||
|
||||
if room:
|
||||
rooms_result.append(room)
|
||||
else:
|
||||
rooms, events = await self._summarize_remote_room(
|
||||
fed_rooms, fed_events = await self._summarize_remote_room(
|
||||
queue_entry,
|
||||
suggested_only,
|
||||
max_children,
|
||||
exclude_rooms=processed_rooms,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Query of %s returned rooms %s, events %s",
|
||||
queue_entry.room_id,
|
||||
[room.get("room_id") for room in rooms],
|
||||
["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events],
|
||||
)
|
||||
# The results over federation might include rooms that the we,
|
||||
# as the requesting server, are allowed to see, but the requesting
|
||||
# user is not permitted see.
|
||||
#
|
||||
# Filter the returned results to only what is accessible to the user.
|
||||
room_ids = set()
|
||||
events = []
|
||||
for room in fed_rooms:
|
||||
fed_room_id = room.get("room_id")
|
||||
if not fed_room_id or not isinstance(fed_room_id, str):
|
||||
continue
|
||||
|
||||
rooms_result.extend(rooms)
|
||||
# The room should only be included in the summary if:
|
||||
# a. the user is in the room;
|
||||
# b. the room is world readable; or
|
||||
# c. the user is in a space that has been granted access to
|
||||
# the room.
|
||||
#
|
||||
# Note that we know the user is not in the root room (which is
|
||||
# why the remote call was made in the first place), but the user
|
||||
# could be in one of the children rooms and we just didn't know
|
||||
# about the link.
|
||||
include_room = room.get("world_readable") is True
|
||||
|
||||
# any rooms returned don't need visiting again
|
||||
processed_rooms.update(cast(str, room.get("room_id")) for room in rooms)
|
||||
# Check if the user is a member of any of the allowed spaces
|
||||
# from the response.
|
||||
allowed_spaces = room.get("allowed_spaces")
|
||||
if (
|
||||
not include_room
|
||||
and allowed_spaces
|
||||
and isinstance(allowed_spaces, list)
|
||||
):
|
||||
include_room = await self._event_auth_handler.is_user_in_rooms(
|
||||
allowed_spaces, requester
|
||||
)
|
||||
|
||||
# Finally, if this isn't the requested room, check ourselves
|
||||
# if we can access the room.
|
||||
if not include_room and fed_room_id != queue_entry.room_id:
|
||||
include_room = await self._is_room_accessible(
|
||||
fed_room_id, requester, None
|
||||
)
|
||||
|
||||
# The user can see the room, include it!
|
||||
if include_room:
|
||||
rooms_result.append(room)
|
||||
room_ids.add(fed_room_id)
|
||||
|
||||
# All rooms returned don't need visiting again (even if the user
|
||||
# didn't have access to them).
|
||||
processed_rooms.add(fed_room_id)
|
||||
|
||||
for event in fed_events:
|
||||
if event.get("room_id") in room_ids:
|
||||
events.append(event)
|
||||
|
||||
logger.debug(
|
||||
"Query of %s returned rooms %s, events %s",
|
||||
room_id,
|
||||
[room.get("room_id") for room in fed_rooms],
|
||||
["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in fed_events],
|
||||
)
|
||||
|
||||
# the room we queried may or may not have been returned, but don't process
|
||||
# it again, anyway.
|
||||
@ -159,10 +222,16 @@ class SpaceSummaryHandler:
|
||||
)
|
||||
processed_events.add(ev_key)
|
||||
|
||||
# Before returning to the client, remove the allowed_spaces key for any
|
||||
# rooms.
|
||||
for room in rooms_result:
|
||||
room.pop("allowed_spaces", None)
|
||||
|
||||
return {"rooms": rooms_result, "events": events_result}
|
||||
|
||||
async def federation_space_summary(
|
||||
self,
|
||||
origin: str,
|
||||
room_id: str,
|
||||
suggested_only: bool,
|
||||
max_rooms_per_space: Optional[int],
|
||||
@ -172,6 +241,8 @@ class SpaceSummaryHandler:
|
||||
Implementation of the space summary Federation API
|
||||
|
||||
Args:
|
||||
origin: The server requesting the spaces summary.
|
||||
|
||||
room_id: room id to start the summary at
|
||||
|
||||
suggested_only: whether we should only return children with the "suggested"
|
||||
@ -206,14 +277,15 @@ class SpaceSummaryHandler:
|
||||
|
||||
logger.debug("Processing room %s", room_id)
|
||||
|
||||
rooms, events = await self._summarize_local_room(
|
||||
None, room_id, suggested_only, max_rooms_per_space
|
||||
room, events = await self._summarize_local_room(
|
||||
None, origin, room_id, suggested_only, max_rooms_per_space
|
||||
)
|
||||
|
||||
processed_rooms.add(room_id)
|
||||
|
||||
rooms_result.extend(rooms)
|
||||
events_result.extend(events)
|
||||
if room:
|
||||
rooms_result.append(room)
|
||||
events_result.extend(events)
|
||||
|
||||
# add any children to the queue
|
||||
room_queue.extend(edge_event["state_key"] for edge_event in events)
|
||||
@ -223,19 +295,27 @@ class SpaceSummaryHandler:
|
||||
async def _summarize_local_room(
|
||||
self,
|
||||
requester: Optional[str],
|
||||
origin: Optional[str],
|
||||
room_id: str,
|
||||
suggested_only: bool,
|
||||
max_children: Optional[int],
|
||||
) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]:
|
||||
) -> Tuple[Optional[JsonDict], Sequence[JsonDict]]:
|
||||
"""
|
||||
Generate a room entry and a list of event entries for a given room.
|
||||
|
||||
Args:
|
||||
requester: The requesting user, or None if this is over federation.
|
||||
requester:
|
||||
The user requesting the summary, if it is a local request. None
|
||||
if this is a federation request.
|
||||
origin:
|
||||
The server requesting the summary, if it is a federation request.
|
||||
None if this is a local request.
|
||||
room_id: The room ID to summarize.
|
||||
suggested_only: True if only suggested children should be returned.
|
||||
Otherwise, all children are returned.
|
||||
max_children: The maximum number of children to return for this node.
|
||||
max_children:
|
||||
The maximum number of children rooms to include. This is capped
|
||||
to a server-set limit.
|
||||
|
||||
Returns:
|
||||
A tuple of:
|
||||
@ -244,8 +324,8 @@ class SpaceSummaryHandler:
|
||||
An iterable of the sorted children events. This may be limited
|
||||
to a maximum size or may include all children.
|
||||
"""
|
||||
if not await self._is_room_accessible(room_id, requester):
|
||||
return (), ()
|
||||
if not await self._is_room_accessible(room_id, requester, origin):
|
||||
return None, ()
|
||||
|
||||
room_entry = await self._build_room_entry(room_id)
|
||||
|
||||
@ -269,7 +349,7 @@ class SpaceSummaryHandler:
|
||||
event_format=format_event_for_client_v2,
|
||||
)
|
||||
)
|
||||
return (room_entry,), events_result
|
||||
return room_entry, events_result
|
||||
|
||||
async def _summarize_remote_room(
|
||||
self,
|
||||
@ -278,6 +358,26 @@ class SpaceSummaryHandler:
|
||||
max_children: Optional[int],
|
||||
exclude_rooms: Iterable[str],
|
||||
) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]:
|
||||
"""
|
||||
Request room entries and a list of event entries for a given room by querying a remote server.
|
||||
|
||||
Args:
|
||||
room: The room to summarize.
|
||||
suggested_only: True if only suggested children should be returned.
|
||||
Otherwise, all children are returned.
|
||||
max_children:
|
||||
The maximum number of children rooms to include. This is capped
|
||||
to a server-set limit.
|
||||
exclude_rooms:
|
||||
Rooms IDs which do not need to be summarized.
|
||||
|
||||
Returns:
|
||||
A tuple of:
|
||||
An iterable of rooms.
|
||||
|
||||
An iterable of the sorted children events. This may be limited
|
||||
to a maximum size or may include all children.
|
||||
"""
|
||||
room_id = room.room_id
|
||||
logger.info("Requesting summary for %s via %s", room_id, room.via)
|
||||
|
||||
@ -309,27 +409,93 @@ class SpaceSummaryHandler:
|
||||
or ev.event_type == EventTypes.SpaceChild
|
||||
)
|
||||
|
||||
async def _is_room_accessible(self, room_id: str, requester: Optional[str]) -> bool:
|
||||
# if we have an authenticated requesting user, first check if they are in the
|
||||
# room
|
||||
async def _is_room_accessible(
|
||||
self, room_id: str, requester: Optional[str], origin: Optional[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Calculate whether the room should be shown in the spaces summary.
|
||||
|
||||
It should be included if:
|
||||
|
||||
* The requester is joined or invited to the room.
|
||||
* The requester can join without an invite (per MSC3083).
|
||||
* The origin server has any user that is joined or invited to the room.
|
||||
* The history visibility is set to world readable.
|
||||
|
||||
Args:
|
||||
room_id: The room ID to summarize.
|
||||
requester:
|
||||
The user requesting the summary, if it is a local request. None
|
||||
if this is a federation request.
|
||||
origin:
|
||||
The server requesting the summary, if it is a federation request.
|
||||
None if this is a local request.
|
||||
|
||||
Returns:
|
||||
True if the room should be included in the spaces summary.
|
||||
"""
|
||||
state_ids = await self._store.get_current_state_ids(room_id)
|
||||
|
||||
# If there's no state for the room, it isn't known.
|
||||
if not state_ids:
|
||||
logger.info("room %s is unknown, omitting from summary", room_id)
|
||||
return False
|
||||
|
||||
room_version = await self._store.get_room_version(room_id)
|
||||
|
||||
# if we have an authenticated requesting user, first check if they are able to view
|
||||
# stripped state in the room.
|
||||
if requester:
|
||||
member_event_id = state_ids.get((EventTypes.Member, requester), None)
|
||||
|
||||
# If they're in the room they can see info on it.
|
||||
member_event = None
|
||||
if member_event_id:
|
||||
member_event = await self._store.get_event(member_event_id)
|
||||
if member_event.membership in (Membership.JOIN, Membership.INVITE):
|
||||
return True
|
||||
|
||||
# Otherwise, check if they should be allowed access via membership in a space.
|
||||
try:
|
||||
await self._auth.check_user_in_room(room_id, requester)
|
||||
return True
|
||||
await self._event_auth_handler.check_restricted_join_rules(
|
||||
state_ids, room_version, requester, member_event
|
||||
)
|
||||
except AuthError:
|
||||
# The user doesn't have access due to spaces, but might have access
|
||||
# another way. Keep trying.
|
||||
pass
|
||||
else:
|
||||
return True
|
||||
|
||||
# If this is a request over federation, check if the host is in the room or
|
||||
# is in one of the spaces specified via the join rules.
|
||||
elif origin:
|
||||
if await self._auth.check_host_in_room(room_id, origin):
|
||||
return True
|
||||
|
||||
# Alternately, if the host has a user in any of the spaces specified
|
||||
# for access, then the host can see this room (and should do filtering
|
||||
# if the requester cannot see it).
|
||||
if await self._event_auth_handler.has_restricted_join_rules(
|
||||
state_ids, room_version
|
||||
):
|
||||
allowed_spaces = (
|
||||
await self._event_auth_handler.get_spaces_that_allow_join(state_ids)
|
||||
)
|
||||
for space_id in allowed_spaces:
|
||||
if await self._auth.check_host_in_room(space_id, origin):
|
||||
return True
|
||||
|
||||
# otherwise, check if the room is peekable
|
||||
hist_vis_ev = await self._state_handler.get_current_state(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
)
|
||||
if hist_vis_ev:
|
||||
hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""), None)
|
||||
if hist_vis_event_id:
|
||||
hist_vis_ev = await self._store.get_event(hist_vis_event_id)
|
||||
hist_vis = hist_vis_ev.content.get("history_visibility")
|
||||
if hist_vis == HistoryVisibility.WORLD_READABLE:
|
||||
return True
|
||||
|
||||
logger.info(
|
||||
"room %s is unpeekable and user %s is not a member, omitting from summary",
|
||||
"room %s is unpeekable and user %s is not a member / not allowed to join, omitting from summary",
|
||||
room_id,
|
||||
requester,
|
||||
)
|
||||
@ -354,6 +520,15 @@ class SpaceSummaryHandler:
|
||||
if not room_type:
|
||||
room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE)
|
||||
|
||||
room_version = await self._store.get_room_version(room_id)
|
||||
allowed_spaces = None
|
||||
if await self._event_auth_handler.has_restricted_join_rules(
|
||||
current_state_ids, room_version
|
||||
):
|
||||
allowed_spaces = await self._event_auth_handler.get_spaces_that_allow_join(
|
||||
current_state_ids
|
||||
)
|
||||
|
||||
entry = {
|
||||
"room_id": stats["room_id"],
|
||||
"name": stats["name"],
|
||||
@ -367,6 +542,7 @@ class SpaceSummaryHandler:
|
||||
"guest_can_join": stats["guest_access"] == "can_join",
|
||||
"creation_ts": create_event.origin_server_ts,
|
||||
"room_type": room_type,
|
||||
"allowed_spaces": allowed_spaces,
|
||||
}
|
||||
|
||||
# Filter out Nones – rather omit the field altogether
|
||||
@ -430,8 +606,8 @@ def _is_suggested_child_event(edge_event: EventBase) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# Order may only contain characters in the range of \x20 (space) to \x7F (~).
|
||||
_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7F]")
|
||||
# Order may only contain characters in the range of \x20 (space) to \x7E (~) inclusive.
|
||||
_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]")
|
||||
|
||||
|
||||
def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]:
|
||||
|
@ -814,7 +814,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||
if self.deferred.called:
|
||||
return
|
||||
|
||||
self.stream.write(data)
|
||||
try:
|
||||
self.stream.write(data)
|
||||
except Exception:
|
||||
self.deferred.errback()
|
||||
return
|
||||
|
||||
self.length += len(data)
|
||||
# The first time the maximum size is exceeded, error and cancel the
|
||||
# connection. dataReceived might be called again if data was received
|
||||
|
@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import cgi
|
||||
import codecs
|
||||
import logging
|
||||
@ -19,13 +20,24 @@ import sys
|
||||
import typing
|
||||
import urllib.parse
|
||||
from io import BytesIO, StringIO
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import attr
|
||||
import treq
|
||||
from canonicaljson import encode_canonical_json
|
||||
from prometheus_client import Counter
|
||||
from signedjson.sign import sign_json
|
||||
from typing_extensions import Literal
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.error import DNSLookupError
|
||||
@ -48,6 +60,7 @@ from synapse.http.client import (
|
||||
BlacklistingAgentWrapper,
|
||||
BlacklistingReactorWrapper,
|
||||
BodyExceededMaxSize,
|
||||
ByteWriteable,
|
||||
encode_query_args,
|
||||
read_body_with_max_size,
|
||||
)
|
||||
@ -88,6 +101,27 @@ _next_id = 1
|
||||
QueryArgs = Dict[str, Union[str, List[str]]]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ByteParser(ByteWriteable, Generic[T], abc.ABC):
|
||||
"""A `ByteWriteable` that has an additional `finish` function that returns
|
||||
the parsed data.
|
||||
"""
|
||||
|
||||
CONTENT_TYPE = abc.abstractproperty() # type: str # type: ignore
|
||||
"""The expected content type of the response, e.g. `application/json`. If
|
||||
the content type doesn't match we fail the request.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def finish(self) -> T:
|
||||
"""Called when response has finished streaming and the parser should
|
||||
return the final result (or error).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class MatrixFederationRequest:
|
||||
method = attr.ib(type=str)
|
||||
@ -148,15 +182,32 @@ class MatrixFederationRequest:
|
||||
return self.json
|
||||
|
||||
|
||||
async def _handle_json_response(
|
||||
class JsonParser(ByteParser[Union[JsonDict, list]]):
|
||||
"""A parser that buffers the response and tries to parse it as JSON."""
|
||||
|
||||
CONTENT_TYPE = "application/json"
|
||||
|
||||
def __init__(self):
|
||||
self._buffer = StringIO()
|
||||
self._binary_wrapper = BinaryIOWrapper(self._buffer)
|
||||
|
||||
def write(self, data: bytes) -> int:
|
||||
return self._binary_wrapper.write(data)
|
||||
|
||||
def finish(self) -> Union[JsonDict, list]:
|
||||
return json_decoder.decode(self._buffer.getvalue())
|
||||
|
||||
|
||||
async def _handle_response(
|
||||
reactor: IReactorTime,
|
||||
timeout_sec: float,
|
||||
request: MatrixFederationRequest,
|
||||
response: IResponse,
|
||||
start_ms: int,
|
||||
) -> JsonDict:
|
||||
parser: ByteParser[T],
|
||||
) -> T:
|
||||
"""
|
||||
Reads the JSON body of a response, with a timeout
|
||||
Reads the body of a response with a timeout and sends it to a parser
|
||||
|
||||
Args:
|
||||
reactor: twisted reactor, for the timeout
|
||||
@ -164,23 +215,21 @@ async def _handle_json_response(
|
||||
request: the request that triggered the response
|
||||
response: response to the request
|
||||
start_ms: Timestamp when request was made
|
||||
parser: The parser for the response
|
||||
|
||||
Returns:
|
||||
The parsed JSON response
|
||||
The parsed response
|
||||
"""
|
||||
try:
|
||||
check_content_type_is_json(response.headers)
|
||||
|
||||
buf = StringIO()
|
||||
d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE)
|
||||
try:
|
||||
check_content_type_is(response.headers, parser.CONTENT_TYPE)
|
||||
|
||||
d = read_body_with_max_size(response, parser, MAX_RESPONSE_SIZE)
|
||||
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
|
||||
|
||||
def parse(_len: int):
|
||||
return json_decoder.decode(buf.getvalue())
|
||||
length = await make_deferred_yieldable(d)
|
||||
|
||||
d.addCallback(parse)
|
||||
|
||||
body = await make_deferred_yieldable(d)
|
||||
value = parser.finish()
|
||||
except BodyExceededMaxSize as e:
|
||||
# The response was too big.
|
||||
logger.warning(
|
||||
@ -193,9 +242,9 @@ async def _handle_json_response(
|
||||
)
|
||||
raise RequestSendFailed(e, can_retry=False) from e
|
||||
except ValueError as e:
|
||||
# The JSON content was invalid.
|
||||
# The content was invalid.
|
||||
logger.warning(
|
||||
"{%s} [%s] Failed to parse JSON response - %s %s",
|
||||
"{%s} [%s] Failed to parse response - %s %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
request.method,
|
||||
@ -225,16 +274,17 @@ async def _handle_json_response(
|
||||
time_taken_secs = reactor.seconds() - start_ms / 1000
|
||||
|
||||
logger.info(
|
||||
"{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
|
||||
"{%s} [%s] Completed request: %d %s in %.2f secs, got %d bytes - %s %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
response.code,
|
||||
response.phrase.decode("ascii", errors="replace"),
|
||||
time_taken_secs,
|
||||
length,
|
||||
request.method,
|
||||
request.uri.decode("ascii"),
|
||||
)
|
||||
return body
|
||||
return value
|
||||
|
||||
|
||||
class BinaryIOWrapper:
|
||||
@ -671,6 +721,7 @@ class MatrixFederationHttpClient:
|
||||
)
|
||||
return auth_headers
|
||||
|
||||
@overload
|
||||
async def put_json(
|
||||
self,
|
||||
destination: str,
|
||||
@ -683,7 +734,41 @@ class MatrixFederationHttpClient:
|
||||
ignore_backoff: bool = False,
|
||||
backoff_on_404: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Literal[None] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def put_json(
|
||||
self,
|
||||
destination: str,
|
||||
path: str,
|
||||
args: Optional[QueryArgs] = None,
|
||||
data: Optional[JsonDict] = None,
|
||||
json_data_callback: Optional[Callable[[], JsonDict]] = None,
|
||||
long_retries: bool = False,
|
||||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
backoff_on_404: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Optional[ByteParser[T]] = None,
|
||||
) -> T:
|
||||
...
|
||||
|
||||
async def put_json(
|
||||
self,
|
||||
destination: str,
|
||||
path: str,
|
||||
args: Optional[QueryArgs] = None,
|
||||
data: Optional[JsonDict] = None,
|
||||
json_data_callback: Optional[Callable[[], JsonDict]] = None,
|
||||
long_retries: bool = False,
|
||||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
backoff_on_404: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Optional[ByteParser] = None,
|
||||
):
|
||||
"""Sends the specified json data using PUT
|
||||
|
||||
Args:
|
||||
@ -716,6 +801,8 @@ class MatrixFederationHttpClient:
|
||||
of the request. Workaround for #3622 in Synapse <= v0.99.3. This
|
||||
will be attempted before backing off if backing off has been
|
||||
enabled.
|
||||
parser: The parser to use to decode the response. Defaults to
|
||||
parsing as JSON.
|
||||
|
||||
Returns:
|
||||
Succeeds when we get a 2xx HTTP response. The
|
||||
@ -756,8 +843,16 @@ class MatrixFederationHttpClient:
|
||||
else:
|
||||
_sec_timeout = self.default_timeout
|
||||
|
||||
body = await _handle_json_response(
|
||||
self.reactor, _sec_timeout, request, response, start_ms
|
||||
if parser is None:
|
||||
parser = JsonParser()
|
||||
|
||||
body = await _handle_response(
|
||||
self.reactor,
|
||||
_sec_timeout,
|
||||
request,
|
||||
response,
|
||||
start_ms,
|
||||
parser=parser,
|
||||
)
|
||||
|
||||
return body
|
||||
@ -830,12 +925,8 @@ class MatrixFederationHttpClient:
|
||||
else:
|
||||
_sec_timeout = self.default_timeout
|
||||
|
||||
body = await _handle_json_response(
|
||||
self.reactor,
|
||||
_sec_timeout,
|
||||
request,
|
||||
response,
|
||||
start_ms,
|
||||
body = await _handle_response(
|
||||
self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
|
||||
)
|
||||
return body
|
||||
|
||||
@ -907,8 +998,8 @@ class MatrixFederationHttpClient:
|
||||
else:
|
||||
_sec_timeout = self.default_timeout
|
||||
|
||||
body = await _handle_json_response(
|
||||
self.reactor, _sec_timeout, request, response, start_ms
|
||||
body = await _handle_response(
|
||||
self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
|
||||
)
|
||||
|
||||
return body
|
||||
@ -975,8 +1066,8 @@ class MatrixFederationHttpClient:
|
||||
else:
|
||||
_sec_timeout = self.default_timeout
|
||||
|
||||
body = await _handle_json_response(
|
||||
self.reactor, _sec_timeout, request, response, start_ms
|
||||
body = await _handle_response(
|
||||
self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
|
||||
)
|
||||
return body
|
||||
|
||||
@ -1068,16 +1159,16 @@ def _flatten_response_never_received(e):
|
||||
return repr(e)
|
||||
|
||||
|
||||
def check_content_type_is_json(headers: Headers) -> None:
|
||||
def check_content_type_is(headers: Headers, expected_content_type: str) -> None:
|
||||
"""
|
||||
Check that a set of HTTP headers have a Content-Type header, and that it
|
||||
is application/json.
|
||||
is the expected value..
|
||||
|
||||
Args:
|
||||
headers: headers to check
|
||||
|
||||
Raises:
|
||||
RequestSendFailed: if the Content-Type header is missing or isn't JSON
|
||||
RequestSendFailed: if the Content-Type header is missing or doesn't match
|
||||
|
||||
"""
|
||||
content_type_headers = headers.getRawHeaders(b"Content-Type")
|
||||
@ -1089,11 +1180,10 @@ def check_content_type_is_json(headers: Headers) -> None:
|
||||
|
||||
c_type = content_type_headers[0].decode("ascii") # only the first header
|
||||
val, options = cgi.parse_header(c_type)
|
||||
if val != "application/json":
|
||||
if val != expected_content_type:
|
||||
raise RequestSendFailed(
|
||||
RuntimeError(
|
||||
"Remote server sent Content-Type header of '%s', not 'application/json'"
|
||||
% c_type,
|
||||
f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'",
|
||||
),
|
||||
can_retry=False,
|
||||
)
|
||||
|
@ -56,14 +56,6 @@ class ModuleApi:
|
||||
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
|
||||
self._public_room_list_manager = PublicRoomListManager(hs)
|
||||
|
||||
# The next time these users sync, they will receive the current presence
|
||||
# state of all local users. Users are added by send_local_online_presence_to,
|
||||
# and removed after a successful sync.
|
||||
#
|
||||
# We make this a private variable to deter modules from accessing it directly,
|
||||
# though other classes in Synapse will still do so.
|
||||
self._send_full_presence_to_local_users = set()
|
||||
|
||||
@property
|
||||
def http_client(self):
|
||||
"""Allows making outbound HTTP requests to remote resources.
|
||||
@ -405,39 +397,44 @@ class ModuleApi:
|
||||
Updates to remote users will be sent immediately, whereas local users will receive
|
||||
them on their next sync attempt.
|
||||
|
||||
Note that this method can only be run on the main or federation_sender worker
|
||||
processes.
|
||||
Note that this method can only be run on the process that is configured to write to the
|
||||
presence stream. By default this is the main process.
|
||||
"""
|
||||
if not self._hs.should_send_federation():
|
||||
if self._hs._instance_name not in self._hs.config.worker.writers.presence:
|
||||
raise Exception(
|
||||
"send_local_online_presence_to can only be run "
|
||||
"on processes that send federation",
|
||||
"on the process that is configured to write to the "
|
||||
"presence stream (by default this is the main process)",
|
||||
)
|
||||
|
||||
local_users = set()
|
||||
remote_users = set()
|
||||
for user in users:
|
||||
if self._hs.is_mine_id(user):
|
||||
# Modify SyncHandler._generate_sync_entry_for_presence to call
|
||||
# presence_source.get_new_events with an empty `from_key` if
|
||||
# that user's ID were in a list modified by ModuleApi somewhere.
|
||||
# That user would then get all presence state on next incremental sync.
|
||||
|
||||
# Force a presence initial_sync for this user next time
|
||||
self._send_full_presence_to_local_users.add(user)
|
||||
local_users.add(user)
|
||||
else:
|
||||
# Retrieve presence state for currently online users that this user
|
||||
# is considered interested in
|
||||
presence_events, _ = await self._presence_stream.get_new_events(
|
||||
UserID.from_string(user), from_key=None, include_offline=False
|
||||
)
|
||||
remote_users.add(user)
|
||||
|
||||
# Send to remote destinations.
|
||||
# We pull out the presence handler here to break a cyclic
|
||||
# dependency between the presence router and module API.
|
||||
presence_handler = self._hs.get_presence_handler()
|
||||
|
||||
# We pull out the presence handler here to break a cyclic
|
||||
# dependency between the presence router and module API.
|
||||
presence_handler = self._hs.get_presence_handler()
|
||||
await presence_handler.maybe_send_presence_to_interested_destinations(
|
||||
presence_events
|
||||
)
|
||||
if local_users:
|
||||
# Force a presence initial_sync for these users next time they sync.
|
||||
await presence_handler.send_full_presence_to_users(local_users)
|
||||
|
||||
for user in remote_users:
|
||||
# Retrieve presence state for currently online users that this user
|
||||
# is considered interested in.
|
||||
presence_events, _ = await self._presence_stream.get_new_events(
|
||||
UserID.from_string(user), from_key=None, include_offline=False
|
||||
)
|
||||
|
||||
# Send to remote destinations.
|
||||
destination = UserID.from_string(user).domain
|
||||
presence_handler.get_federation_queue().send_presence_to_destinations(
|
||||
presence_events, destination
|
||||
)
|
||||
|
||||
|
||||
class PublicRoomListManager:
|
||||
|
@ -12,12 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import email.mime.multipart
|
||||
import email.utils
|
||||
import logging
|
||||
import urllib.parse
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
|
||||
|
||||
import bleach
|
||||
@ -27,7 +23,6 @@ from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.config.emailconfig import EmailSubjectConfig
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.push.presentable_names import (
|
||||
calculate_room_name,
|
||||
descriptor_from_member_events,
|
||||
@ -108,7 +103,7 @@ class Mailer:
|
||||
self.template_html = template_html
|
||||
self.template_text = template_text
|
||||
|
||||
self.sendmail = self.hs.get_sendmail()
|
||||
self.send_email_handler = hs.get_send_email_handler()
|
||||
self.store = self.hs.get_datastore()
|
||||
self.state_store = self.hs.get_storage().state
|
||||
self.macaroon_gen = self.hs.get_macaroon_generator()
|
||||
@ -310,17 +305,6 @@ class Mailer:
|
||||
self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Send an email with the given information and template text"""
|
||||
try:
|
||||
from_string = self.hs.config.email_notif_from % {"app": self.app_name}
|
||||
except TypeError:
|
||||
from_string = self.hs.config.email_notif_from
|
||||
|
||||
raw_from = email.utils.parseaddr(from_string)[1]
|
||||
raw_to = email.utils.parseaddr(email_address)[1]
|
||||
|
||||
if raw_to == "":
|
||||
raise RuntimeError("Invalid 'to' address")
|
||||
|
||||
template_vars = {
|
||||
"app_name": self.app_name,
|
||||
"server_name": self.hs.config.server.server_name,
|
||||
@ -329,35 +313,14 @@ class Mailer:
|
||||
template_vars.update(extra_template_vars)
|
||||
|
||||
html_text = self.template_html.render(**template_vars)
|
||||
html_part = MIMEText(html_text, "html", "utf8")
|
||||
|
||||
plain_text = self.template_text.render(**template_vars)
|
||||
text_part = MIMEText(plain_text, "plain", "utf8")
|
||||
|
||||
multipart_msg = MIMEMultipart("alternative")
|
||||
multipart_msg["Subject"] = subject
|
||||
multipart_msg["From"] = from_string
|
||||
multipart_msg["To"] = email_address
|
||||
multipart_msg["Date"] = email.utils.formatdate()
|
||||
multipart_msg["Message-ID"] = email.utils.make_msgid()
|
||||
multipart_msg.attach(text_part)
|
||||
multipart_msg.attach(html_part)
|
||||
|
||||
logger.info("Sending email to %s" % email_address)
|
||||
|
||||
await make_deferred_yieldable(
|
||||
self.sendmail(
|
||||
self.hs.config.email_smtp_host,
|
||||
raw_from,
|
||||
raw_to,
|
||||
multipart_msg.as_string().encode("utf8"),
|
||||
reactor=self.hs.get_reactor(),
|
||||
port=self.hs.config.email_smtp_port,
|
||||
requireAuthentication=self.hs.config.email_smtp_user is not None,
|
||||
username=self.hs.config.email_smtp_user,
|
||||
password=self.hs.config.email_smtp_pass,
|
||||
requireTransportSecurity=self.hs.config.require_transport_security,
|
||||
)
|
||||
await self.send_email_handler.send_email(
|
||||
email_address=email_address,
|
||||
subject=subject,
|
||||
app_name=self.app_name,
|
||||
html=html_text,
|
||||
text=plain_text,
|
||||
)
|
||||
|
||||
async def _get_room_vars(
|
||||
|
@ -87,6 +87,7 @@ REQUIREMENTS = [
|
||||
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
||||
# with the latest security patches.
|
||||
"cryptography>=3.4.7",
|
||||
"ijson>=3.0",
|
||||
]
|
||||
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
|
@ -73,6 +73,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
|
||||
{
|
||||
"state": { ... },
|
||||
"ignore_status_msg": false,
|
||||
"force_notify": false
|
||||
}
|
||||
|
||||
200 OK
|
||||
@ -91,17 +92,23 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
|
||||
self._presence_handler = hs.get_presence_handler()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(user_id, state, ignore_status_msg=False):
|
||||
async def _serialize_payload(
|
||||
user_id, state, ignore_status_msg=False, force_notify=False
|
||||
):
|
||||
return {
|
||||
"state": state,
|
||||
"ignore_status_msg": ignore_status_msg,
|
||||
"force_notify": force_notify,
|
||||
}
|
||||
|
||||
async def _handle_request(self, request, user_id):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
await self._presence_handler.set_state(
|
||||
UserID.from_string(user_id), content["state"], content["ignore_status_msg"]
|
||||
UserID.from_string(user_id),
|
||||
content["state"],
|
||||
content["ignore_status_msg"],
|
||||
content["force_notify"],
|
||||
)
|
||||
|
||||
return (
|
||||
|
@ -24,7 +24,7 @@ class SlavedClientIpStore(BaseSlavedStore):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self.client_ip_last_seen = LruCache(
|
||||
cache_name="client_ip_last_seen", keylen=4, max_size=50000
|
||||
cache_name="client_ip_last_seen", max_size=50000
|
||||
) # type: LruCache[tuple, int]
|
||||
|
||||
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
|
||||
|
@ -1,21 +0,0 @@
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.databases.main.transactions import TransactionStore
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
|
||||
|
||||
class SlavedTransactionStore(TransactionStore, BaseSlavedStore):
|
||||
pass
|
@ -54,7 +54,6 @@ class SendServerNoticeServlet(RestServlet):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.txns = HttpTransactionCache(hs)
|
||||
self.snm = hs.get_server_notices_manager()
|
||||
|
||||
def register(self, json_resource: HttpServer):
|
||||
PATTERN = "/send_server_notice"
|
||||
@ -77,7 +76,10 @@ class SendServerNoticeServlet(RestServlet):
|
||||
event_type = body.get("type", EventTypes.Message)
|
||||
state_key = body.get("state_key")
|
||||
|
||||
if not self.snm.is_enabled():
|
||||
# We grab the server notices manager here as its initialisation has a check for worker processes,
|
||||
# but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
|
||||
# admin api).
|
||||
if not self.hs.get_server_notices_manager().is_enabled():
|
||||
raise SynapseError(400, "Server notices are not enabled on this server")
|
||||
|
||||
user_id = body["user_id"]
|
||||
@ -85,7 +87,7 @@ class SendServerNoticeServlet(RestServlet):
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
raise SynapseError(400, "Server notices can only be sent to local users")
|
||||
|
||||
event = await self.snm.send_notice(
|
||||
event = await self.hs.get_server_notices_manager().send_notice(
|
||||
user_id=body["user_id"],
|
||||
type=event_type,
|
||||
state_key=state_key,
|
||||
|
@ -48,11 +48,6 @@ class LocalKey(Resource):
|
||||
"key": # base64 encoded NACL verification key.
|
||||
}
|
||||
},
|
||||
"tls_fingerprints": [ # Fingerprints of the TLS certs this server uses.
|
||||
{
|
||||
"sha256": # base64 encoded sha256 fingerprint of the X509 cert
|
||||
},
|
||||
],
|
||||
"signatures": {
|
||||
"this.server.example.com": {
|
||||
"algorithm:version": # NACL signature for this server
|
||||
@ -89,14 +84,11 @@ class LocalKey(Resource):
|
||||
"expired_ts": key.expired_ts,
|
||||
}
|
||||
|
||||
tls_fingerprints = self.config.tls_fingerprints
|
||||
|
||||
json_object = {
|
||||
"valid_until_ts": self.valid_until_ts,
|
||||
"server_name": self.config.server_name,
|
||||
"verify_keys": verify_keys,
|
||||
"old_verify_keys": old_verify_keys,
|
||||
"tls_fingerprints": tls_fingerprints,
|
||||
}
|
||||
for key in self.config.signing_key:
|
||||
json_object = sign_json(json_object, self.config.server_name, key)
|
||||
|
@ -73,9 +73,6 @@ class RemoteKey(DirectServeJsonResource):
|
||||
"expired_ts": 0, # when the key stop being used.
|
||||
}
|
||||
}
|
||||
"tls_fingerprints": [
|
||||
{ "sha256": # fingerprint }
|
||||
]
|
||||
"signatures": {
|
||||
"remote.server.example.com": {...}
|
||||
"this.server.example.com": {...}
|
||||
|
@ -76,6 +76,8 @@ class MediaRepository:
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.max_image_pixels = hs.config.max_image_pixels
|
||||
|
||||
Thumbnailer.set_limits(self.max_image_pixels)
|
||||
|
||||
self.primary_base_path = hs.config.media_store_path # type: str
|
||||
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
|
||||
|
||||
|
@ -40,6 +40,10 @@ class Thumbnailer:
|
||||
|
||||
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG", "image/webp": "WEBP"}
|
||||
|
||||
@staticmethod
|
||||
def set_limits(max_image_pixels: int):
|
||||
Image.MAX_IMAGE_PIXELS = max_image_pixels
|
||||
|
||||
def __init__(self, input_path: str):
|
||||
try:
|
||||
self.image = Image.open(input_path)
|
||||
@ -47,6 +51,11 @@ class Thumbnailer:
|
||||
# If an error occurs opening the image, a thumbnail won't be able to
|
||||
# be generated.
|
||||
raise ThumbnailError from e
|
||||
except Image.DecompressionBombError as e:
|
||||
# If an image decompression bomb error occurs opening the image,
|
||||
# then the image exceeds the pixel limit and a thumbnail won't
|
||||
# be able to be generated.
|
||||
raise ThumbnailError from e
|
||||
|
||||
self.width, self.height = self.image.size
|
||||
self.transpose_method = None
|
||||
|
@ -104,6 +104,7 @@ from synapse.handlers.room_list import RoomListHandler
|
||||
from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
|
||||
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
|
||||
from synapse.handlers.search import SearchHandler
|
||||
from synapse.handlers.send_email import SendEmailHandler
|
||||
from synapse.handlers.set_password import SetPasswordHandler
|
||||
from synapse.handlers.space_summary import SpaceSummaryHandler
|
||||
from synapse.handlers.sso import SsoHandler
|
||||
@ -549,6 +550,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
def get_search_handler(self) -> SearchHandler:
|
||||
return SearchHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_send_email_handler(self) -> SendEmailHandler:
|
||||
return SendEmailHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_set_password_handler(self) -> SetPasswordHandler:
|
||||
return SetPasswordHandler(self)
|
||||
|
@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
from abc import ABCMeta
|
||||
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
|
||||
|
||||
@ -44,7 +43,6 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
self._clock = hs.get_clock()
|
||||
self.database_engine = database.engine
|
||||
self.db_pool = database
|
||||
self.rand = random.SystemRandom()
|
||||
|
||||
def process_replication_rows(
|
||||
self,
|
||||
|
@ -67,7 +67,7 @@ from .state import StateStore
|
||||
from .stats import StatsStore
|
||||
from .stream import StreamStore
|
||||
from .tags import TagsStore
|
||||
from .transactions import TransactionStore
|
||||
from .transactions import TransactionWorkerStore
|
||||
from .ui_auth import UIAuthStore
|
||||
from .user_directory import UserDirectoryStore
|
||||
from .user_erasure_store import UserErasureStore
|
||||
@ -83,7 +83,7 @@ class DataStore(
|
||||
StreamStore,
|
||||
ProfileStore,
|
||||
PresenceStore,
|
||||
TransactionStore,
|
||||
TransactionWorkerStore,
|
||||
DirectoryStore,
|
||||
KeyStore,
|
||||
StateStore,
|
||||
|
@ -436,7 +436,7 @@ class ClientIpStore(ClientIpWorkerStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
|
||||
self.client_ip_last_seen = LruCache(
|
||||
cache_name="client_ip_last_seen", keylen=4, max_size=50000
|
||||
cache_name="client_ip_last_seen", max_size=50000
|
||||
)
|
||||
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||
list_name="user_ids",
|
||||
)
|
||||
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
||||
async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="device_lists_remote_extremeties",
|
||||
column="user_id",
|
||||
@ -1053,7 +1053,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
||||
# the device exists.
|
||||
self.device_id_exists_cache = LruCache(
|
||||
cache_name="device_id_exists", keylen=2, max_size=10000
|
||||
cache_name="device_id_exists", max_size=10000
|
||||
)
|
||||
|
||||
async def store_device(
|
||||
|
@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
||||
num_args=1,
|
||||
)
|
||||
async def _get_bare_e2e_cross_signing_keys_bulk(
|
||||
self, user_ids: List[str]
|
||||
self, user_ids: Iterable[str]
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Returns the cross-signing keys for a set of users. The output of this
|
||||
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||
@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
||||
def _get_bare_e2e_cross_signing_keys_bulk_txn(
|
||||
self,
|
||||
txn: Connection,
|
||||
user_ids: List[str],
|
||||
user_ids: Iterable[str],
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Returns the cross-signing keys for a set of users. The output of this
|
||||
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||
|
@ -157,7 +157,6 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
self._get_event_cache = LruCache(
|
||||
cache_name="*getEvent*",
|
||||
keylen=3,
|
||||
max_size=hs.config.caches.event_cache_size,
|
||||
)
|
||||
|
||||
|
@ -55,7 +55,7 @@ class KeyStore(SQLBaseStore):
|
||||
"""
|
||||
keys = {}
|
||||
|
||||
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None:
|
||||
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
|
||||
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
|
||||
|
||||
# batch_iter always returns tuples so it's safe to do len(batch)
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
|
||||
|
||||
from synapse.api.presence import PresenceState, UserPresenceState
|
||||
from synapse.replication.tcp.streams import PresenceStream
|
||||
@ -57,6 +57,7 @@ class PresenceStore(SQLBaseStore):
|
||||
db_conn, "presence_stream", "stream_id"
|
||||
)
|
||||
|
||||
self.hs = hs
|
||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||
|
||||
presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
|
||||
@ -96,6 +97,15 @@ class PresenceStore(SQLBaseStore):
|
||||
)
|
||||
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
|
||||
|
||||
# Delete old rows to stop database from getting really big
|
||||
sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
|
||||
|
||||
for states in batch_iter(presence_states, 50):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "user_id", [s.user_id for s in states]
|
||||
)
|
||||
txn.execute(sql + clause, [stream_id] + list(args))
|
||||
|
||||
# Actually insert new rows
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
@ -116,15 +126,6 @@ class PresenceStore(SQLBaseStore):
|
||||
],
|
||||
)
|
||||
|
||||
# Delete old rows to stop database from getting really big
|
||||
sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
|
||||
|
||||
for states in batch_iter(presence_states, 50):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "user_id", [s.user_id for s in states]
|
||||
)
|
||||
txn.execute(sql + clause, [stream_id] + list(args))
|
||||
|
||||
async def get_all_presence_updates(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, list]], int, bool]:
|
||||
@ -210,6 +211,61 @@ class PresenceStore(SQLBaseStore):
|
||||
|
||||
return {row["user_id"]: UserPresenceState(**row) for row in rows}
|
||||
|
||||
async def should_user_receive_full_presence_with_token(
|
||||
self,
|
||||
user_id: str,
|
||||
from_token: int,
|
||||
) -> bool:
|
||||
"""Check whether the given user should receive full presence using the stream token
|
||||
they're updating from.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to check.
|
||||
from_token: The stream token included in their /sync token.
|
||||
|
||||
Returns:
|
||||
True if the user should have full presence sent to them, False otherwise.
|
||||
"""
|
||||
|
||||
def _should_user_receive_full_presence_with_token_txn(txn):
|
||||
sql = """
|
||||
SELECT 1 FROM users_to_send_full_presence_to
|
||||
WHERE user_id = ?
|
||||
AND presence_stream_id >= ?
|
||||
"""
|
||||
txn.execute(sql, (user_id, from_token))
|
||||
return bool(txn.fetchone())
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"should_user_receive_full_presence_with_token",
|
||||
_should_user_receive_full_presence_with_token_txn,
|
||||
)
|
||||
|
||||
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
|
||||
"""Adds to the list of users who should receive a full snapshot of presence
|
||||
upon their next sync.
|
||||
|
||||
Args:
|
||||
user_ids: An iterable of user IDs.
|
||||
"""
|
||||
# Add user entries to the table, updating the presence_stream_id column if the user already
|
||||
# exists in the table.
|
||||
await self.db_pool.simple_upsert_many(
|
||||
table="users_to_send_full_presence_to",
|
||||
key_names=("user_id",),
|
||||
key_values=[(user_id,) for user_id in user_ids],
|
||||
value_names=("presence_stream_id",),
|
||||
# We save the current presence stream ID token along with the user ID entry so
|
||||
# that when a user /sync's, even if they syncing multiple times across separate
|
||||
# devices at different times, each device will receive full presence once - when
|
||||
# the presence stream ID in their sync token is less than the one in the table
|
||||
# for their user ID.
|
||||
value_values=(
|
||||
(self._presence_id_gen.get_current_token(),) for _ in user_ids
|
||||
),
|
||||
desc="add_users_to_send_full_presence_to",
|
||||
)
|
||||
|
||||
async def get_presence_for_all_users(
|
||||
self,
|
||||
include_offline: bool = True,
|
||||
|
@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@ -997,7 +998,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
expiration_ts = now_ms + self._account_validity_period
|
||||
|
||||
if use_delta:
|
||||
expiration_ts = self.rand.randrange(
|
||||
expiration_ts = random.randrange(
|
||||
expiration_ts - self._account_validity_startup_job_max_delta,
|
||||
expiration_ts,
|
||||
)
|
||||
|
@ -16,13 +16,15 @@ import logging
|
||||
from collections import namedtuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage._base import db_to_json
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
db_binary_type = memoryview
|
||||
|
||||
@ -38,10 +40,23 @@ _UpdateTransactionRow = namedtuple(
|
||||
"_TransactionRow", ("response_code", "response_json")
|
||||
)
|
||||
|
||||
SENTINEL = object()
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class DestinationRetryTimings:
|
||||
"""The current destination retry timing info for a remote server."""
|
||||
|
||||
# The first time we tried and failed to reach the remote server, in ms.
|
||||
failure_ts: int
|
||||
|
||||
# The last time we tried and failed to reach the remote server, in ms.
|
||||
retry_last_ts: int
|
||||
|
||||
# How long since the last time we tried to reach the remote server before
|
||||
# trying again, in ms.
|
||||
retry_interval: int
|
||||
|
||||
|
||||
class TransactionWorkerStore(SQLBaseStore):
|
||||
class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
@ -60,19 +75,6 @@ class TransactionWorkerStore(SQLBaseStore):
|
||||
"_cleanup_transactions", _cleanup_transactions_txn
|
||||
)
|
||||
|
||||
|
||||
class TransactionStore(TransactionWorkerStore):
|
||||
"""A collection of queries for handling PDUs."""
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._destination_retry_cache = ExpiringCache(
|
||||
cache_name="get_destination_retry_timings",
|
||||
clock=self._clock,
|
||||
expiry_ms=5 * 60 * 1000,
|
||||
)
|
||||
|
||||
async def get_received_txn_response(
|
||||
self, transaction_id: str, origin: str
|
||||
) -> Optional[Tuple[int, JsonDict]]:
|
||||
@ -145,7 +147,11 @@ class TransactionStore(TransactionWorkerStore):
|
||||
desc="set_received_txn_response",
|
||||
)
|
||||
|
||||
async def get_destination_retry_timings(self, destination):
|
||||
@cached(max_entries=10000)
|
||||
async def get_destination_retry_timings(
|
||||
self,
|
||||
destination: str,
|
||||
) -> Optional[DestinationRetryTimings]:
|
||||
"""Gets the current retry timings (if any) for a given destination.
|
||||
|
||||
Args:
|
||||
@ -156,34 +162,29 @@ class TransactionStore(TransactionWorkerStore):
|
||||
Otherwise a dict for the retry scheme
|
||||
"""
|
||||
|
||||
result = self._destination_retry_cache.get(destination, SENTINEL)
|
||||
if result is not SENTINEL:
|
||||
return result
|
||||
|
||||
result = await self.db_pool.runInteraction(
|
||||
"get_destination_retry_timings",
|
||||
self._get_destination_retry_timings,
|
||||
destination,
|
||||
)
|
||||
|
||||
# We don't hugely care about race conditions between getting and
|
||||
# invalidating the cache, since we time out fairly quickly anyway.
|
||||
self._destination_retry_cache[destination] = result
|
||||
return result
|
||||
|
||||
def _get_destination_retry_timings(self, txn, destination):
|
||||
def _get_destination_retry_timings(
|
||||
self, txn, destination: str
|
||||
) -> Optional[DestinationRetryTimings]:
|
||||
result = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="destinations",
|
||||
keyvalues={"destination": destination},
|
||||
retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
|
||||
retcols=("failure_ts", "retry_last_ts", "retry_interval"),
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
# check we have a row and retry_last_ts is not null or zero
|
||||
# (retry_last_ts can't be negative)
|
||||
if result and result["retry_last_ts"]:
|
||||
return result
|
||||
return DestinationRetryTimings(**result)
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -204,7 +205,6 @@ class TransactionStore(TransactionWorkerStore):
|
||||
retry_interval: how long until next retry in ms
|
||||
"""
|
||||
|
||||
self._destination_retry_cache.pop(destination, None)
|
||||
if self.database_engine.can_native_upsert:
|
||||
return await self.db_pool.runInteraction(
|
||||
"set_destination_retry_timings",
|
||||
@ -252,6 +252,10 @@ class TransactionStore(TransactionWorkerStore):
|
||||
|
||||
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_destination_retry_timings, (destination,)
|
||||
)
|
||||
|
||||
def _set_destination_retry_timings_emulated(
|
||||
self, txn, destination, failure_ts, retry_last_ts, retry_interval
|
||||
):
|
||||
@ -295,6 +299,10 @@ class TransactionStore(TransactionWorkerStore):
|
||||
},
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_destination_retry_timings, (destination,)
|
||||
)
|
||||
|
||||
async def store_destination_rooms_entries(
|
||||
self,
|
||||
destinations: Iterable[str],
|
||||
|
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Iterable
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||
return bool(result)
|
||||
|
||||
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
|
||||
async def are_users_erased(self, user_ids):
|
||||
async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
|
||||
"""
|
||||
Checks which users in a list have requested erasure
|
||||
|
||||
Args:
|
||||
user_ids (iterable[str]): full user id to check
|
||||
user_ids: full user ids to check
|
||||
|
||||
Returns:
|
||||
dict[str, bool]:
|
||||
for each user, whether the user has requested erasure.
|
||||
for each user, whether the user has requested erasure.
|
||||
"""
|
||||
# this serves the dual purpose of (a) making sure we can do len and
|
||||
# iterate it multiple times, and (b) avoiding duplicates.
|
||||
user_ids = tuple(set(user_ids))
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="erased_users",
|
||||
column="user_id",
|
||||
|
@ -0,0 +1,34 @@
|
||||
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Add a table that keeps track of a list of users who should, upon their next
|
||||
-- sync request, receive presence for all currently online users that they are
|
||||
-- "interested" in.
|
||||
|
||||
-- The motivation for a DB table over an in-memory list is so that this list
|
||||
-- can be added to and retrieved from by any worker. Specifically, we don't
|
||||
-- want to duplicate work across multiple sync workers.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users_to_send_full_presence_to(
|
||||
-- The user ID to send full presence to.
|
||||
user_id TEXT PRIMARY KEY,
|
||||
-- A presence stream ID token - the current presence stream token when the row was last upserted.
|
||||
-- If a user calls /sync and this token is part of the update they're to receive, we also include
|
||||
-- full user presence in the response.
|
||||
-- This allows multiple devices for a user to receive full presence whenever they next call /sync.
|
||||
presence_stream_id BIGINT,
|
||||
FOREIGN KEY (user_id)
|
||||
REFERENCES users (name)
|
||||
);
|
@ -540,7 +540,7 @@ class StateGroupStorage:
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A dict from (type, state_key) -> state_event
|
||||
A dict from (type, state_key) -> state_event_id
|
||||
"""
|
||||
state_map = await self.get_state_ids_for_events(
|
||||
[event_id], state_filter or StateFilter.all()
|
||||
|
153
synapse/util/batching_queue.py
Normal file
153
synapse/util/batching_queue.py
Normal file
@ -0,0 +1,153 @@
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
Hashable,
|
||||
List,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util import Clock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
V = TypeVar("V")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class BatchingQueue(Generic[V, R]):
|
||||
"""A queue that batches up work, calling the provided processing function
|
||||
with all pending work (for a given key).
|
||||
|
||||
The provided processing function will only be called once at a time for each
|
||||
key. It will be called the next reactor tick after `add_to_queue` has been
|
||||
called, and will keep being called until the queue has been drained (for the
|
||||
given key).
|
||||
|
||||
Note that the return value of `add_to_queue` will be the return value of the
|
||||
processing function that processed the given item. This means that the
|
||||
returned value will likely include data for other items that were in the
|
||||
batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
clock: Clock,
|
||||
process_batch_callback: Callable[[List[V]], Awaitable[R]],
|
||||
):
|
||||
self._name = name
|
||||
self._clock = clock
|
||||
|
||||
# The set of keys currently being processed.
|
||||
self._processing_keys = set() # type: Set[Hashable]
|
||||
|
||||
# The currently pending batch of values by key, with a Deferred to call
|
||||
# with the result of the corresponding `_process_batch_callback` call.
|
||||
self._next_values = {} # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]]
|
||||
|
||||
# The function to call with batches of values.
|
||||
self._process_batch_callback = process_batch_callback
|
||||
|
||||
LaterGauge(
|
||||
"synapse_util_batching_queue_number_queued",
|
||||
"The number of items waiting in the queue across all keys",
|
||||
labels=("name",),
|
||||
caller=lambda: sum(len(v) for v in self._next_values.values()),
|
||||
)
|
||||
|
||||
LaterGauge(
|
||||
"synapse_util_batching_queue_number_of_keys",
|
||||
"The number of distinct keys that have items queued",
|
||||
labels=("name",),
|
||||
caller=lambda: len(self._next_values),
|
||||
)
|
||||
|
||||
async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
|
||||
"""Adds the value to the queue with the given key, returning the result
|
||||
of the processing function for the batch that included the given value.
|
||||
|
||||
The optional `key` argument allows sharding the queue by some key. The
|
||||
queues will then be processed in parallel, i.e. the process batch
|
||||
function will be called in parallel with batched values from a single
|
||||
key.
|
||||
"""
|
||||
|
||||
# First we create a defer and add it and the value to the list of
|
||||
# pending items.
|
||||
d = defer.Deferred()
|
||||
self._next_values.setdefault(key, []).append((value, d))
|
||||
|
||||
# If we're not currently processing the key fire off a background
|
||||
# process to start processing.
|
||||
if key not in self._processing_keys:
|
||||
run_as_background_process(self._name, self._process_queue, key)
|
||||
|
||||
return await make_deferred_yieldable(d)
|
||||
|
||||
async def _process_queue(self, key: Hashable) -> None:
|
||||
"""A background task to repeatedly pull things off the queue for the
|
||||
given key and call the `self._process_batch_callback` with the values.
|
||||
"""
|
||||
|
||||
try:
|
||||
if key in self._processing_keys:
|
||||
return
|
||||
|
||||
self._processing_keys.add(key)
|
||||
|
||||
while True:
|
||||
# We purposefully wait a reactor tick to allow us to batch
|
||||
# together requests that we're about to receive. A common
|
||||
# pattern is to call `add_to_queue` multiple times at once, and
|
||||
# deferring to the next reactor tick allows us to batch all of
|
||||
# those up.
|
||||
await self._clock.sleep(0)
|
||||
|
||||
next_values = self._next_values.pop(key, [])
|
||||
if not next_values:
|
||||
# We've exhausted the queue.
|
||||
break
|
||||
|
||||
try:
|
||||
values = [value for value, _ in next_values]
|
||||
results = await self._process_batch_callback(values)
|
||||
|
||||
for _, deferred in next_values:
|
||||
with PreserveLoggingContext():
|
||||
deferred.callback(results)
|
||||
|
||||
except Exception as e:
|
||||
for _, deferred in next_values:
|
||||
if deferred.called:
|
||||
continue
|
||||
|
||||
with PreserveLoggingContext():
|
||||
deferred.errback(e)
|
||||
|
||||
finally:
|
||||
self._processing_keys.discard(key)
|
@ -70,7 +70,6 @@ class DeferredCache(Generic[KT, VT]):
|
||||
self,
|
||||
name: str,
|
||||
max_entries: int = 1000,
|
||||
keylen: int = 1,
|
||||
tree: bool = False,
|
||||
iterable: bool = False,
|
||||
apply_cache_factor_from_config: bool = True,
|
||||
@ -101,7 +100,6 @@ class DeferredCache(Generic[KT, VT]):
|
||||
# a Deferred.
|
||||
self.cache = LruCache(
|
||||
max_size=max_entries,
|
||||
keylen=keylen,
|
||||
cache_name=name,
|
||||
cache_type=cache_type,
|
||||
size_callback=(lambda d: len(d) or 1) if iterable else None,
|
||||
|
@ -270,7 +270,6 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||
cache = DeferredCache(
|
||||
name=self.orig.__name__,
|
||||
max_entries=self.max_entries,
|
||||
keylen=self.num_args,
|
||||
tree=self.tree,
|
||||
iterable=self.iterable,
|
||||
) # type: DeferredCache[CacheKey, Any]
|
||||
@ -322,8 +321,8 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||
class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||
"""Wraps an existing cache to support bulk fetching of keys.
|
||||
|
||||
Given a list of keys it looks in the cache to find any hits, then passes
|
||||
the list of missing keys to the wrapped function.
|
||||
Given an iterable of keys it looks in the cache to find any hits, then passes
|
||||
the tuple of missing keys to the wrapped function.
|
||||
|
||||
Once wrapped, the function returns a Deferred which resolves to the list
|
||||
of results.
|
||||
@ -437,7 +436,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||
return f
|
||||
|
||||
args_to_call = dict(arg_dict)
|
||||
args_to_call[self.list_name] = list(missing)
|
||||
# copy the missing set before sending it to the callee, to guard against
|
||||
# modification.
|
||||
args_to_call[self.list_name] = tuple(missing)
|
||||
|
||||
cached_defers.append(
|
||||
defer.maybeDeferred(
|
||||
@ -522,14 +523,14 @@ def cachedList(
|
||||
|
||||
Used to do batch lookups for an already created cache. A single argument
|
||||
is specified as a list that is iterated through to lookup keys in the
|
||||
original cache. A new list consisting of the keys that weren't in the cache
|
||||
get passed to the original function, the result of which is stored in the
|
||||
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
|
||||
the cache gets passed to the original function, the result of which is stored in the
|
||||
cache.
|
||||
|
||||
Args:
|
||||
cached_method_name: The name of the single-item lookup method.
|
||||
This is only used to find the cache to use.
|
||||
list_name: The name of the argument that is the list to use to
|
||||
list_name: The name of the argument that is the iterable to use to
|
||||
do batch lookups in the cache.
|
||||
num_args: Number of arguments to use as the key in the cache
|
||||
(including list_name). Defaults to all named parameters.
|
||||
|
@ -34,7 +34,7 @@ from typing_extensions import Literal
|
||||
from synapse.config import cache as cache_config
|
||||
from synapse.util import caches
|
||||
from synapse.util.caches import CacheMetric, register_cache
|
||||
from synapse.util.caches.treecache import TreeCache
|
||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||
|
||||
try:
|
||||
from pympler.asizeof import Asizer
|
||||
@ -160,7 +160,6 @@ class LruCache(Generic[KT, VT]):
|
||||
self,
|
||||
max_size: int,
|
||||
cache_name: Optional[str] = None,
|
||||
keylen: int = 1,
|
||||
cache_type: Type[Union[dict, TreeCache]] = dict,
|
||||
size_callback: Optional[Callable] = None,
|
||||
metrics_collection_callback: Optional[Callable[[], None]] = None,
|
||||
@ -173,9 +172,6 @@ class LruCache(Generic[KT, VT]):
|
||||
cache_name: The name of this cache, for the prometheus metrics. If unset,
|
||||
no metrics will be reported on this cache.
|
||||
|
||||
keylen: The length of the tuple used as the cache key. Ignored unless
|
||||
cache_type is `TreeCache`.
|
||||
|
||||
cache_type (type):
|
||||
type of underlying cache to be used. Typically one of dict
|
||||
or TreeCache.
|
||||
@ -403,7 +399,9 @@ class LruCache(Generic[KT, VT]):
|
||||
popped = cache.pop(key)
|
||||
if popped is None:
|
||||
return
|
||||
for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
|
||||
# for each deleted node, we now need to remove it from the linked list
|
||||
# and run its callbacks.
|
||||
for leaf in iterate_tree_cache_entry(popped):
|
||||
delete_node(leaf)
|
||||
|
||||
@synchronized
|
||||
|
@ -1,18 +1,43 @@
|
||||
from typing import Dict
|
||||
# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SENTINEL = object()
|
||||
|
||||
|
||||
class TreeCacheNode(dict):
|
||||
"""The type of nodes in our tree.
|
||||
|
||||
Has its own type so we can distinguish it from real dicts that are stored at the
|
||||
leaves.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TreeCache:
|
||||
"""
|
||||
Tree-based backing store for LruCache. Allows subtrees of data to be deleted
|
||||
efficiently.
|
||||
Keys must be tuples.
|
||||
|
||||
The data structure is a chain of TreeCacheNodes:
|
||||
root = {key_1: {key_2: _value}}
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.size = 0
|
||||
self.root = {} # type: Dict
|
||||
self.root = TreeCacheNode()
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
return self.set(key, value)
|
||||
@ -21,10 +46,23 @@ class TreeCache:
|
||||
return self.get(key, SENTINEL) is not SENTINEL
|
||||
|
||||
def set(self, key, value):
|
||||
if isinstance(value, TreeCacheNode):
|
||||
# this would mean we couldn't tell where our tree ended and the value
|
||||
# started.
|
||||
raise ValueError("Cannot store TreeCacheNodes in a TreeCache")
|
||||
|
||||
node = self.root
|
||||
for k in key[:-1]:
|
||||
node = node.setdefault(k, {})
|
||||
node[key[-1]] = _Entry(value)
|
||||
next_node = node.get(k, SENTINEL)
|
||||
if next_node is SENTINEL:
|
||||
next_node = node[k] = TreeCacheNode()
|
||||
elif not isinstance(next_node, TreeCacheNode):
|
||||
# this suggests that the caller is not being consistent with its key
|
||||
# length.
|
||||
raise ValueError("value conflicts with an existing subtree")
|
||||
node = next_node
|
||||
|
||||
node[key[-1]] = value
|
||||
self.size += 1
|
||||
|
||||
def get(self, key, default=None):
|
||||
@ -33,25 +71,41 @@ class TreeCache:
|
||||
node = node.get(k, None)
|
||||
if node is None:
|
||||
return default
|
||||
return node.get(key[-1], _Entry(default)).value
|
||||
return node.get(key[-1], default)
|
||||
|
||||
def clear(self):
|
||||
self.size = 0
|
||||
self.root = {}
|
||||
self.root = TreeCacheNode()
|
||||
|
||||
def pop(self, key, default=None):
|
||||
"""Remove the given key, or subkey, from the cache
|
||||
|
||||
Args:
|
||||
key: key or subkey to remove.
|
||||
default: value to return if key is not found
|
||||
|
||||
Returns:
|
||||
If the key is not found, 'default'. If the key is complete, the removed
|
||||
value. If the key is partial, the TreeCacheNode corresponding to the part
|
||||
of the tree that was removed.
|
||||
"""
|
||||
# a list of the nodes we have touched on the way down the tree
|
||||
nodes = []
|
||||
|
||||
node = self.root
|
||||
for k in key[:-1]:
|
||||
node = node.get(k, None)
|
||||
nodes.append(node) # don't add the root node
|
||||
if node is None:
|
||||
return default
|
||||
if not isinstance(node, TreeCacheNode):
|
||||
# we've gone off the end of the tree
|
||||
raise ValueError("pop() key too long")
|
||||
nodes.append(node) # don't add the root node
|
||||
popped = node.pop(key[-1], SENTINEL)
|
||||
if popped is SENTINEL:
|
||||
return default
|
||||
|
||||
# working back up the tree, clear out any nodes that are now empty
|
||||
node_and_keys = list(zip(nodes, key))
|
||||
node_and_keys.reverse()
|
||||
node_and_keys.append((self.root, None))
|
||||
@ -61,14 +115,15 @@ class TreeCache:
|
||||
|
||||
if n:
|
||||
break
|
||||
# found an empty node: remove it from its parent, and loop.
|
||||
node_and_keys[i + 1][0].pop(k)
|
||||
|
||||
popped, cnt = _strip_and_count_entires(popped)
|
||||
cnt = sum(1 for _ in iterate_tree_cache_entry(popped))
|
||||
self.size -= cnt
|
||||
return popped
|
||||
|
||||
def values(self):
|
||||
return list(iterate_tree_cache_entry(self.root))
|
||||
return iterate_tree_cache_entry(self.root)
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
@ -78,36 +133,9 @@ def iterate_tree_cache_entry(d):
|
||||
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
|
||||
can contain dicts.
|
||||
"""
|
||||
if isinstance(d, dict):
|
||||
if isinstance(d, TreeCacheNode):
|
||||
for value_d in d.values():
|
||||
for value in iterate_tree_cache_entry(value_d):
|
||||
yield value
|
||||
else:
|
||||
if isinstance(d, _Entry):
|
||||
yield d.value
|
||||
else:
|
||||
yield d
|
||||
|
||||
|
||||
class _Entry:
|
||||
__slots__ = ["value"]
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
||||
def _strip_and_count_entires(d):
|
||||
"""Takes an _Entry or dict with leaves of _Entry's, and either returns the
|
||||
value or a dictionary with _Entry's replaced by their values.
|
||||
|
||||
Also returns the count of _Entry's
|
||||
"""
|
||||
if isinstance(d, dict):
|
||||
cnt = 0
|
||||
for key, value in d.items():
|
||||
v, n = _strip_and_count_entires(value)
|
||||
d[key] = v
|
||||
cnt += n
|
||||
return d, cnt
|
||||
else:
|
||||
return d.value, 1
|
||||
yield d
|
||||
|
@ -17,15 +17,15 @@ import hashlib
|
||||
import unpaddedbase64
|
||||
|
||||
|
||||
def sha256_and_url_safe_base64(input_text):
|
||||
def sha256_and_url_safe_base64(input_text: str) -> str:
|
||||
"""SHA256 hash an input string, encode the digest as url-safe base64, and
|
||||
return
|
||||
|
||||
:param input_text: string to hash
|
||||
:type input_text: str
|
||||
Args:
|
||||
input_text: string to hash
|
||||
|
||||
:returns a sha256 hashed and url-safe base64 encoded digest
|
||||
:rtype: str
|
||||
returns:
|
||||
A sha256 hashed and url-safe base64 encoded digest
|
||||
"""
|
||||
digest = hashlib.sha256(input_text.encode()).digest()
|
||||
return unpaddedbase64.encode_base64(digest, urlsafe=True)
|
||||
|
@ -30,12 +30,12 @@ from typing import (
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]:
|
||||
def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
|
||||
"""batch an iterable up into tuples with a maximum size
|
||||
|
||||
Args:
|
||||
iterable (iterable): the iterable to slice
|
||||
size (int): the maximum batch size
|
||||
iterable: the iterable to slice
|
||||
size: the maximum batch size
|
||||
|
||||
Returns:
|
||||
an iterator over the chunks
|
||||
@ -46,10 +46,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]:
|
||||
return iter(lambda: tuple(islice(sourceiter, size)), ())
|
||||
|
||||
|
||||
ISeq = TypeVar("ISeq", bound=Sequence, covariant=True)
|
||||
|
||||
|
||||
def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
|
||||
def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]:
|
||||
"""Split the given sequence into chunks of the given size
|
||||
|
||||
The last chunk may be shorter than the given size.
|
||||
|
@ -15,6 +15,7 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import itertools
|
||||
from types import ModuleType
|
||||
from typing import Any, Iterable, Tuple, Type
|
||||
|
||||
import jsonschema
|
||||
@ -44,8 +45,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
|
||||
|
||||
# We need to import the module, and then pick the class out of
|
||||
# that, so we split based on the last dot.
|
||||
module, clz = modulename.rsplit(".", 1)
|
||||
module = importlib.import_module(module)
|
||||
module_name, clz = modulename.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
provider_class = getattr(module, clz)
|
||||
|
||||
# Load the module config. If None, pass an empty dictionary instead
|
||||
@ -69,11 +70,11 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
|
||||
return provider_class, provider_config
|
||||
|
||||
|
||||
def load_python_module(location: str):
|
||||
def load_python_module(location: str) -> ModuleType:
|
||||
"""Load a python module, and return a reference to its global namespace
|
||||
|
||||
Args:
|
||||
location (str): path to the module
|
||||
location: path to the module
|
||||
|
||||
Returns:
|
||||
python module object
|
||||
|
@ -17,19 +17,19 @@ import phonenumbers
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
|
||||
def phone_number_to_msisdn(country, number):
|
||||
def phone_number_to_msisdn(country: str, number: str) -> str:
|
||||
"""
|
||||
Takes an ISO-3166-1 2 letter country code and phone number and
|
||||
returns an msisdn representing the canonical version of that
|
||||
phone number.
|
||||
Args:
|
||||
country (str): ISO-3166-1 2 letter country code
|
||||
number (str): Phone number in a national or international format
|
||||
country: ISO-3166-1 2 letter country code
|
||||
number: Phone number in a national or international format
|
||||
|
||||
Returns:
|
||||
(str) The canonical form of the phone number, as an msisdn
|
||||
The canonical form of the phone number, as an msisdn
|
||||
Raises:
|
||||
SynapseError if the number could not be parsed.
|
||||
SynapseError if the number could not be parsed.
|
||||
"""
|
||||
try:
|
||||
phoneNumber = phonenumbers.parse(number, country)
|
||||
|
@ -82,11 +82,9 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
|
||||
retry_timings = await store.get_destination_retry_timings(destination)
|
||||
|
||||
if retry_timings:
|
||||
failure_ts = retry_timings["failure_ts"]
|
||||
retry_last_ts, retry_interval = (
|
||||
retry_timings["retry_last_ts"],
|
||||
retry_timings["retry_interval"],
|
||||
)
|
||||
failure_ts = retry_timings.failure_ts
|
||||
retry_last_ts = retry_timings.retry_last_ts
|
||||
retry_interval = retry_timings.retry_interval
|
||||
|
||||
now = int(clock.time_msec())
|
||||
|
||||
|
@ -13,8 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import itertools
|
||||
import random
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Tuple
|
||||
@ -35,26 +35,27 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
|
||||
#
|
||||
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
|
||||
|
||||
# random_string and random_string_with_symbols are used for a range of things,
|
||||
# some cryptographically important, some less so. We use SystemRandom to make sure
|
||||
# we get cryptographically-secure randoms.
|
||||
rand = random.SystemRandom()
|
||||
|
||||
|
||||
def random_string(length: int) -> str:
|
||||
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
|
||||
"""Generate a cryptographically secure string of random letters.
|
||||
|
||||
Drawn from the characters: `a-z` and `A-Z`
|
||||
"""
|
||||
return "".join(secrets.choice(string.ascii_letters) for _ in range(length))
|
||||
|
||||
|
||||
def random_string_with_symbols(length: int) -> str:
|
||||
return "".join(rand.choice(_string_with_symbols) for _ in range(length))
|
||||
"""Generate a cryptographically secure string of random letters/numbers/symbols.
|
||||
|
||||
Drawn from the characters: `a-z`, `A-Z`, `0-9`, and `.,;:^&*-_+=#~@`
|
||||
"""
|
||||
return "".join(secrets.choice(_string_with_symbols) for _ in range(length))
|
||||
|
||||
|
||||
def is_ascii(s: bytes) -> bool:
|
||||
try:
|
||||
s.decode("ascii").encode("ascii")
|
||||
except UnicodeDecodeError:
|
||||
return False
|
||||
except UnicodeEncodeError:
|
||||
except UnicodeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
102
synctl
102
synctl
@ -24,12 +24,13 @@ import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable
|
||||
|
||||
import yaml
|
||||
|
||||
from synapse.config import find_config_files
|
||||
|
||||
SYNAPSE = [sys.executable, "-m", "synapse.app.homeserver"]
|
||||
MAIN_PROCESS = "synapse.app.homeserver"
|
||||
|
||||
GREEN = "\x1b[1;32m"
|
||||
YELLOW = "\x1b[1;33m"
|
||||
@ -68,71 +69,37 @@ def abort(message, colour=RED, stream=sys.stderr):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def start(configfile: str, daemonize: bool = True) -> bool:
|
||||
"""Attempts to start synapse.
|
||||
def start(pidfile: str, app: str, config_files: Iterable[str], daemonize: bool) -> bool:
|
||||
"""Attempts to start a synapse main or worker process.
|
||||
Args:
|
||||
configfile: path to a yaml synapse config file
|
||||
daemonize: whether to daemonize synapse or keep it attached to the current
|
||||
session
|
||||
pidfile: the pidfile we expect the process to create
|
||||
app: the python module to run
|
||||
config_files: config files to pass to synapse
|
||||
daemonize: if True, will include a --daemonize argument to synapse
|
||||
|
||||
Returns:
|
||||
True if the process started successfully
|
||||
True if the process started successfully or was already running
|
||||
False if there was an error starting the process
|
||||
|
||||
If deamonize is False it will only return once synapse exits.
|
||||
"""
|
||||
|
||||
write("Starting ...")
|
||||
args = SYNAPSE
|
||||
if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())):
|
||||
print(app + " already running")
|
||||
return True
|
||||
|
||||
args = [sys.executable, "-m", app]
|
||||
for c in config_files:
|
||||
args += ["-c", c]
|
||||
if daemonize:
|
||||
args.extend(["--daemonize", "-c", configfile])
|
||||
else:
|
||||
args.extend(["-c", configfile])
|
||||
args.append("--daemonize")
|
||||
|
||||
try:
|
||||
subprocess.check_call(args)
|
||||
write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
|
||||
write("started %s(%s)" % (app, ",".join(config_files)), colour=GREEN)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
write(
|
||||
"error starting (exit code: %d); see above for logs" % e.returncode,
|
||||
colour=RED,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def start_worker(app: str, configfile: str, worker_configfile: str) -> bool:
|
||||
"""Attempts to start a synapse worker.
|
||||
Args:
|
||||
app: name of the worker's appservice
|
||||
configfile: path to a yaml synapse config file
|
||||
worker_configfile: path to worker specific yaml synapse file
|
||||
|
||||
Returns:
|
||||
True if the process started successfully
|
||||
False if there was an error starting the process
|
||||
"""
|
||||
|
||||
args = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
app,
|
||||
"-c",
|
||||
configfile,
|
||||
"-c",
|
||||
worker_configfile,
|
||||
"--daemonize",
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.check_call(args)
|
||||
write("started %s(%r)" % (app, worker_configfile), colour=GREEN)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
write(
|
||||
"error starting %s(%r) (exit code: %d); see above for logs"
|
||||
% (app, worker_configfile, e.returncode),
|
||||
"error starting %s(%s) (exit code: %d); see above for logs"
|
||||
% (app, ",".join(config_files), e.returncode),
|
||||
colour=RED,
|
||||
)
|
||||
return False
|
||||
@ -224,10 +191,11 @@ def main():
|
||||
|
||||
if not os.path.exists(configfile):
|
||||
write(
|
||||
"No config file found\n"
|
||||
"To generate a config file, run '%s -c %s --generate-config"
|
||||
" --server-name=<server name> --report-stats=<yes/no>'\n"
|
||||
% (" ".join(SYNAPSE), options.configfile),
|
||||
f"Config file {configfile} does not exist.\n"
|
||||
f"To generate a config file, run:\n"
|
||||
f" {sys.executable} -m {MAIN_PROCESS}"
|
||||
f" -c {configfile} --generate-config"
|
||||
f" --server-name=<server name> --report-stats=<yes/no>\n",
|
||||
stream=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
@ -323,7 +291,7 @@ def main():
|
||||
has_stopped = False
|
||||
|
||||
if start_stop_synapse:
|
||||
if not stop(pidfile, "synapse.app.homeserver"):
|
||||
if not stop(pidfile, MAIN_PROCESS):
|
||||
has_stopped = False
|
||||
if not has_stopped and action == "stop":
|
||||
sys.exit(1)
|
||||
@ -346,30 +314,24 @@ def main():
|
||||
if action == "start" or action == "restart":
|
||||
error = False
|
||||
if start_stop_synapse:
|
||||
# Check if synapse is already running
|
||||
if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())):
|
||||
abort("synapse.app.homeserver already running")
|
||||
|
||||
if not start(configfile, bool(options.daemonize)):
|
||||
if not start(pidfile, MAIN_PROCESS, (configfile,), options.daemonize):
|
||||
error = True
|
||||
|
||||
for worker in workers:
|
||||
env = os.environ.copy()
|
||||
|
||||
# Skip starting a worker if its already running
|
||||
if os.path.exists(worker.pidfile) and pid_running(
|
||||
int(open(worker.pidfile).read())
|
||||
):
|
||||
print(worker.app + " already running")
|
||||
continue
|
||||
|
||||
if worker.cache_factor:
|
||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
|
||||
|
||||
for cache_name, factor in worker.cache_factors.items():
|
||||
os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
|
||||
|
||||
if not start_worker(worker.app, configfile, worker.configfile):
|
||||
if not start(
|
||||
worker.pidfile,
|
||||
worker.app,
|
||||
(configfile, worker.configfile),
|
||||
options.daemonize,
|
||||
):
|
||||
error = True
|
||||
|
||||
# Reset env back to the original
|
||||
|
@ -302,11 +302,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
|
||||
)
|
||||
|
||||
# Check that the expected presence updates were sent
|
||||
expected_users = [
|
||||
# We explicitly compare using sets as we expect that calling
|
||||
# module_api.send_local_online_presence_to will create a presence
|
||||
# update that is a duplicate of the specified user's current presence.
|
||||
# These are sent to clients and will be picked up below, thus we use a
|
||||
# set to deduplicate. We're just interested that non-offline updates were
|
||||
# sent out for each user ID.
|
||||
expected_users = {
|
||||
self.other_user_id,
|
||||
self.presence_receiving_user_one_id,
|
||||
self.presence_receiving_user_two_id,
|
||||
]
|
||||
}
|
||||
found_users = set()
|
||||
|
||||
calls = (
|
||||
self.hs.get_federation_transport_client().send_transaction.call_args_list
|
||||
@ -326,12 +333,12 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
|
||||
# EDUs can contain multiple presence updates
|
||||
for presence_update in edu["content"]["push"]:
|
||||
# Check for presence updates that contain the user IDs we're after
|
||||
expected_users.remove(presence_update["user_id"])
|
||||
found_users.add(presence_update["user_id"])
|
||||
|
||||
# Ensure that no offline states are being sent out
|
||||
self.assertNotEqual(presence_update["presence"], "offline")
|
||||
|
||||
self.assertEqual(len(expected_users), 0)
|
||||
self.assertEqual(found_users, expected_users)
|
||||
|
||||
|
||||
def send_presence_update(
|
||||
|
@ -32,13 +32,19 @@ from synapse.handlers.presence import (
|
||||
handle_timeout,
|
||||
handle_update,
|
||||
)
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client.v1 import room
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class PresenceUpdateTestCase(unittest.TestCase):
|
||||
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [admin.register_servlets]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.store = homeserver.get_datastore()
|
||||
|
||||
def test_offline_to_online(self):
|
||||
wheel_timer = Mock()
|
||||
user_id = "@foo:bar"
|
||||
@ -292,6 +298,45 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
def test_persisting_presence_updates(self):
|
||||
"""Tests that the latest presence state for each user is persisted correctly"""
|
||||
# Create some test users and presence states for them
|
||||
presence_states = []
|
||||
for i in range(5):
|
||||
user_id = self.register_user(f"user_{i}", "password")
|
||||
|
||||
presence_state = UserPresenceState(
|
||||
user_id=user_id,
|
||||
state="online",
|
||||
last_active_ts=1,
|
||||
last_federation_update_ts=1,
|
||||
last_user_sync_ts=1,
|
||||
status_msg="I'm online!",
|
||||
currently_active=True,
|
||||
)
|
||||
presence_states.append(presence_state)
|
||||
|
||||
# Persist these presence updates to the database
|
||||
self.get_success(self.store.update_presence(presence_states))
|
||||
|
||||
# Check that each update is present in the database
|
||||
db_presence_states = self.get_success(
|
||||
self.store.get_all_presence_updates(
|
||||
instance_name="master",
|
||||
last_id=0,
|
||||
current_id=len(presence_states) + 1,
|
||||
limit=len(presence_states),
|
||||
)
|
||||
)
|
||||
|
||||
# Extract presence update user ID and state information into lists of tuples
|
||||
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
|
||||
presence_states = [(ps.user_id, ps.state) for ps in presence_states]
|
||||
|
||||
# Compare what we put into the storage with what we got out.
|
||||
# They should be identical.
|
||||
self.assertEqual(presence_states, db_presence_states)
|
||||
|
||||
|
||||
class PresenceTimeoutTestCase(unittest.TestCase):
|
||||
def test_idle_timer(self):
|
||||
|
@ -89,14 +89,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
self.event_source = hs.get_event_sources().sources["typing"]
|
||||
|
||||
self.datastore = hs.get_datastore()
|
||||
retry_timings_res = {
|
||||
"destination": "",
|
||||
"retry_last_ts": 0,
|
||||
"retry_interval": 0,
|
||||
"failure_ts": None,
|
||||
}
|
||||
self.datastore.get_destination_retry_timings = Mock(
|
||||
return_value=defer.succeed(retry_timings_res)
|
||||
return_value=defer.succeed(None)
|
||||
)
|
||||
|
||||
self.datastore.get_device_updates_by_remote = Mock(
|
||||
|
@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EduTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.federation.units import Transaction
|
||||
@ -22,11 +24,13 @@ from synapse.rest.client.v1 import login, presence, room
|
||||
from synapse.types import create_requester
|
||||
|
||||
from tests.events.test_presence_router import send_presence_update, sync_presence
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
from tests.test_utils.event_injection import inject_member_event
|
||||
from tests.unittest import FederatingHomeserverTestCase, override_config
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||
|
||||
|
||||
class ModuleApiTestCase(FederatingHomeserverTestCase):
|
||||
class ModuleApiTestCase(HomeserverTestCase):
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
@ -217,97 +221,16 @@ class ModuleApiTestCase(FederatingHomeserverTestCase):
|
||||
)
|
||||
self.assertFalse(is_in_public_rooms)
|
||||
|
||||
# The ability to send federation is required by send_local_online_presence_to.
|
||||
@override_config({"send_federation": True})
|
||||
def test_send_local_online_presence_to(self):
|
||||
"""Tests that send_local_presence_to_users sends local online presence to local users."""
|
||||
# Create a user who will send presence updates
|
||||
self.presence_receiver_id = self.register_user("presence_receiver", "monkey")
|
||||
self.presence_receiver_tok = self.login("presence_receiver", "monkey")
|
||||
|
||||
# And another user that will send presence updates out
|
||||
self.presence_sender_id = self.register_user("presence_sender", "monkey")
|
||||
self.presence_sender_tok = self.login("presence_sender", "monkey")
|
||||
|
||||
# Put them in a room together so they will receive each other's presence updates
|
||||
room_id = self.helper.create_room_as(
|
||||
self.presence_receiver_id,
|
||||
tok=self.presence_receiver_tok,
|
||||
)
|
||||
self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok)
|
||||
|
||||
# Presence sender comes online
|
||||
send_presence_update(
|
||||
self,
|
||||
self.presence_sender_id,
|
||||
self.presence_sender_tok,
|
||||
"online",
|
||||
"I'm online!",
|
||||
)
|
||||
|
||||
# Presence receiver should have received it
|
||||
presence_updates, sync_token = sync_presence(self, self.presence_receiver_id)
|
||||
self.assertEqual(len(presence_updates), 1)
|
||||
|
||||
presence_update = presence_updates[0] # type: UserPresenceState
|
||||
self.assertEqual(presence_update.user_id, self.presence_sender_id)
|
||||
self.assertEqual(presence_update.state, "online")
|
||||
|
||||
# Syncing again should result in no presence updates
|
||||
presence_updates, sync_token = sync_presence(
|
||||
self, self.presence_receiver_id, sync_token
|
||||
)
|
||||
self.assertEqual(len(presence_updates), 0)
|
||||
|
||||
# Trigger sending local online presence
|
||||
self.get_success(
|
||||
self.module_api.send_local_online_presence_to(
|
||||
[
|
||||
self.presence_receiver_id,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Presence receiver should have received online presence again
|
||||
presence_updates, sync_token = sync_presence(
|
||||
self, self.presence_receiver_id, sync_token
|
||||
)
|
||||
self.assertEqual(len(presence_updates), 1)
|
||||
|
||||
presence_update = presence_updates[0] # type: UserPresenceState
|
||||
self.assertEqual(presence_update.user_id, self.presence_sender_id)
|
||||
self.assertEqual(presence_update.state, "online")
|
||||
|
||||
# Presence sender goes offline
|
||||
send_presence_update(
|
||||
self,
|
||||
self.presence_sender_id,
|
||||
self.presence_sender_tok,
|
||||
"offline",
|
||||
"I slink back into the darkness.",
|
||||
)
|
||||
|
||||
# Trigger sending local online presence
|
||||
self.get_success(
|
||||
self.module_api.send_local_online_presence_to(
|
||||
[
|
||||
self.presence_receiver_id,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Presence receiver should *not* have received offline state
|
||||
presence_updates, sync_token = sync_presence(
|
||||
self, self.presence_receiver_id, sync_token
|
||||
)
|
||||
self.assertEqual(len(presence_updates), 0)
|
||||
# Test sending local online presence to users from the main process
|
||||
_test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
|
||||
|
||||
@override_config({"send_federation": True})
|
||||
def test_send_local_online_presence_to_federation(self):
|
||||
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
|
||||
# Create a user who will send presence updates
|
||||
self.presence_sender_id = self.register_user("presence_sender", "monkey")
|
||||
self.presence_sender_tok = self.login("presence_sender", "monkey")
|
||||
self.presence_sender_id = self.register_user("presence_sender1", "monkey")
|
||||
self.presence_sender_tok = self.login("presence_sender1", "monkey")
|
||||
|
||||
# And a room they're a part of
|
||||
room_id = self.helper.create_room_as(
|
||||
@ -374,3 +297,209 @@ class ModuleApiTestCase(FederatingHomeserverTestCase):
|
||||
found_update = True
|
||||
|
||||
self.assertTrue(found_update)
|
||||
|
||||
|
||||
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
|
||||
"""For testing ModuleApi functionality in a multi-worker setup"""
|
||||
|
||||
# Testing stream ID replication from the main to worker processes requires postgres
|
||||
# (due to needing `MultiWriterIdGenerator`).
|
||||
if not USE_POSTGRES_FOR_TESTS:
|
||||
skip = "Requires Postgres"
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
presence.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self):
|
||||
conf = super().default_config()
|
||||
conf["redis"] = {"enabled": "true"}
|
||||
conf["stream_writers"] = {"presence": ["presence_writer"]}
|
||||
conf["instance_map"] = {
|
||||
"presence_writer": {"host": "testserv", "port": 1001},
|
||||
}
|
||||
return conf
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.module_api = homeserver.get_module_api()
|
||||
self.sync_handler = homeserver.get_sync_handler()
|
||||
|
||||
def test_send_local_online_presence_to_workers(self):
|
||||
# Test sending local online presence to users from a worker process
|
||||
_test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
|
||||
|
||||
|
||||
def _test_sending_local_online_presence_to_local_user(
|
||||
test_case: HomeserverTestCase, test_with_workers: bool = False
|
||||
):
|
||||
"""Tests that send_local_presence_to_users sends local online presence to local users.
|
||||
|
||||
This simultaneously tests two different usecases:
|
||||
* Testing that this method works when either called from a worker or the main process.
|
||||
- We test this by calling this method from both a TestCase that runs in monolith mode, and one that
|
||||
runs with a main and generic_worker.
|
||||
* Testing that multiple devices syncing simultaneously will all receive a snapshot of local,
|
||||
online presence - but only once per device.
|
||||
|
||||
Args:
|
||||
test_with_workers: If True, this method will call ModuleApi.send_local_online_presence_to on a
|
||||
worker process. The test users will still sync with the main process. The purpose of testing
|
||||
with a worker is to check whether a Synapse module running on a worker can inform other workers/
|
||||
the main process that they should include additional presence when a user next syncs.
|
||||
"""
|
||||
if test_with_workers:
|
||||
# Create a worker process to make module_api calls against
|
||||
worker_hs = test_case.make_worker_hs(
|
||||
"synapse.app.generic_worker", {"worker_name": "presence_writer"}
|
||||
)
|
||||
|
||||
# Create a user who will send presence updates
|
||||
test_case.presence_receiver_id = test_case.register_user(
|
||||
"presence_receiver1", "monkey"
|
||||
)
|
||||
test_case.presence_receiver_tok = test_case.login("presence_receiver1", "monkey")
|
||||
|
||||
# And another user that will send presence updates out
|
||||
test_case.presence_sender_id = test_case.register_user("presence_sender2", "monkey")
|
||||
test_case.presence_sender_tok = test_case.login("presence_sender2", "monkey")
|
||||
|
||||
# Put them in a room together so they will receive each other's presence updates
|
||||
room_id = test_case.helper.create_room_as(
|
||||
test_case.presence_receiver_id,
|
||||
tok=test_case.presence_receiver_tok,
|
||||
)
|
||||
test_case.helper.join(
|
||||
room_id, test_case.presence_sender_id, tok=test_case.presence_sender_tok
|
||||
)
|
||||
|
||||
# Presence sender comes online
|
||||
send_presence_update(
|
||||
test_case,
|
||||
test_case.presence_sender_id,
|
||||
test_case.presence_sender_tok,
|
||||
"online",
|
||||
"I'm online!",
|
||||
)
|
||||
|
||||
# Presence receiver should have received it
|
||||
presence_updates, sync_token = sync_presence(
|
||||
test_case, test_case.presence_receiver_id
|
||||
)
|
||||
test_case.assertEqual(len(presence_updates), 1)
|
||||
|
||||
presence_update = presence_updates[0] # type: UserPresenceState
|
||||
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
|
||||
test_case.assertEqual(presence_update.state, "online")
|
||||
|
||||
if test_with_workers:
|
||||
# Replicate the current sync presence token from the main process to the worker process.
|
||||
# We need to do this so that the worker process knows the current presence stream ID to
|
||||
# insert into the database when we call ModuleApi.send_local_online_presence_to.
|
||||
test_case.replicate()
|
||||
|
||||
# Syncing again should result in no presence updates
|
||||
presence_updates, sync_token = sync_presence(
|
||||
test_case, test_case.presence_receiver_id, sync_token
|
||||
)
|
||||
test_case.assertEqual(len(presence_updates), 0)
|
||||
|
||||
# We do an (initial) sync with a second "device" now, getting a new sync token.
|
||||
# We'll use this in a moment.
|
||||
_, sync_token_second_device = sync_presence(
|
||||
test_case, test_case.presence_receiver_id
|
||||
)
|
||||
|
||||
# Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on
|
||||
if test_with_workers:
|
||||
module_api_to_use = worker_hs.get_module_api()
|
||||
else:
|
||||
module_api_to_use = test_case.module_api
|
||||
|
||||
# Trigger sending local online presence. We expect this information
|
||||
# to be saved to the database where all processes can access it.
|
||||
# Note that we're syncing via the master.
|
||||
d = module_api_to_use.send_local_online_presence_to(
|
||||
[
|
||||
test_case.presence_receiver_id,
|
||||
]
|
||||
)
|
||||
d = defer.ensureDeferred(d)
|
||||
|
||||
if test_with_workers:
|
||||
# In order for the required presence_set_state replication request to occur between the
|
||||
# worker and main process, we need to pump the reactor. Otherwise, the coordinator that
|
||||
# reads the request on the main process won't do so, and the request will time out.
|
||||
while not d.called:
|
||||
test_case.reactor.advance(0.1)
|
||||
|
||||
test_case.get_success(d)
|
||||
|
||||
# The presence receiver should have received online presence again.
|
||||
presence_updates, sync_token = sync_presence(
|
||||
test_case, test_case.presence_receiver_id, sync_token
|
||||
)
|
||||
test_case.assertEqual(len(presence_updates), 1)
|
||||
|
||||
presence_update = presence_updates[0] # type: UserPresenceState
|
||||
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
|
||||
test_case.assertEqual(presence_update.state, "online")
|
||||
|
||||
# We attempt to sync with the second sync token we received above - just to check that
|
||||
# multiple syncing devices will each receive the necessary online presence.
|
||||
presence_updates, sync_token_second_device = sync_presence(
|
||||
test_case, test_case.presence_receiver_id, sync_token_second_device
|
||||
)
|
||||
test_case.assertEqual(len(presence_updates), 1)
|
||||
|
||||
presence_update = presence_updates[0] # type: UserPresenceState
|
||||
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
|
||||
test_case.assertEqual(presence_update.state, "online")
|
||||
|
||||
# However, if we now sync with either "device", we won't receive another burst of online presence
|
||||
# until the API is called again sometime in the future
|
||||
presence_updates, sync_token = sync_presence(
|
||||
test_case, test_case.presence_receiver_id, sync_token
|
||||
)
|
||||
|
||||
# Now we check that we don't receive *offline* updates using ModuleApi.send_local_online_presence_to.
|
||||
|
||||
# Presence sender goes offline
|
||||
send_presence_update(
|
||||
test_case,
|
||||
test_case.presence_sender_id,
|
||||
test_case.presence_sender_tok,
|
||||
"offline",
|
||||
"I slink back into the darkness.",
|
||||
)
|
||||
|
||||
# Presence receiver should have received the updated, offline state
|
||||
presence_updates, sync_token = sync_presence(
|
||||
test_case, test_case.presence_receiver_id, sync_token
|
||||
)
|
||||
test_case.assertEqual(len(presence_updates), 1)
|
||||
|
||||
# Now trigger sending local online presence.
|
||||
d = module_api_to_use.send_local_online_presence_to(
|
||||
[
|
||||
test_case.presence_receiver_id,
|
||||
]
|
||||
)
|
||||
d = defer.ensureDeferred(d)
|
||||
|
||||
if test_with_workers:
|
||||
# In order for the required presence_set_state replication request to occur between the
|
||||
# worker and main process, we need to pump the reactor. Otherwise, the coordinator that
|
||||
# reads the request on the main process won't do so, and the request will time out.
|
||||
while not d.called:
|
||||
test_case.reactor.advance(0.1)
|
||||
|
||||
test_case.get_success(d)
|
||||
|
||||
# Presence receiver should *not* have received offline state
|
||||
presence_updates, sync_token = sync_presence(
|
||||
test_case, test_case.presence_receiver_id, sync_token
|
||||
)
|
||||
test_case.assertEqual(len(presence_updates), 0)
|
||||
|
@ -30,7 +30,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||
"""Checks event persisting sharding works"""
|
||||
|
||||
# Event persister sharding requires postgres (due to needing
|
||||
# `MutliWriterIdGenerator`).
|
||||
# `MultiWriterIdGenerator`).
|
||||
if not USE_POSTGRES_FOR_TESTS:
|
||||
skip = "Requires Postgres"
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.databases.main.transactions import DestinationRetryTimings
|
||||
from synapse.util.retryutils import MAX_RETRY_INTERVAL
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
@ -36,8 +37,11 @@ class TransactionStoreTestCase(HomeserverTestCase):
|
||||
d = self.store.get_destination_retry_timings("example.com")
|
||||
r = self.get_success(d)
|
||||
|
||||
self.assert_dict(
|
||||
{"retry_last_ts": 50, "retry_interval": 100, "failure_ts": 1000}, r
|
||||
self.assertEqual(
|
||||
DestinationRetryTimings(
|
||||
retry_last_ts=50, retry_interval=100, failure_ts=1000
|
||||
),
|
||||
r,
|
||||
)
|
||||
|
||||
def test_initial_set_transactions(self):
|
||||
|
@ -666,18 +666,20 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||
with LoggingContext("c1") as c1:
|
||||
obj = Cls()
|
||||
obj.mock.return_value = {10: "fish", 20: "chips"}
|
||||
|
||||
# start the lookup off
|
||||
d1 = obj.list_fn([10, 20], 2)
|
||||
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||
r = yield d1
|
||||
self.assertEqual(current_context(), c1)
|
||||
obj.mock.assert_called_once_with([10, 20], 2)
|
||||
obj.mock.assert_called_once_with((10, 20), 2)
|
||||
self.assertEqual(r, {10: "fish", 20: "chips"})
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# a call with different params should call the mock again
|
||||
obj.mock.return_value = {30: "peas"}
|
||||
r = yield obj.list_fn([20, 30], 2)
|
||||
obj.mock.assert_called_once_with([30], 2)
|
||||
obj.mock.assert_called_once_with((30,), 2)
|
||||
self.assertEqual(r, {20: "chips", 30: "peas"})
|
||||
obj.mock.reset_mock()
|
||||
|
||||
@ -692,6 +694,15 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||
obj.mock.assert_not_called()
|
||||
self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})
|
||||
|
||||
# we should also be able to use a (single-use) iterable, and should
|
||||
# deduplicate the keys
|
||||
obj.mock.reset_mock()
|
||||
obj.mock.return_value = {40: "gravy"}
|
||||
iterable = (x for x in [10, 40, 40])
|
||||
r = yield obj.list_fn(iterable, 2)
|
||||
obj.mock.assert_called_once_with((40,), 2)
|
||||
self.assertEqual(r, {10: "fish", 40: "gravy"})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_invalidate(self):
|
||||
"""Make sure that invalidation callbacks are called."""
|
||||
@ -717,7 +728,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||
# cache miss
|
||||
obj.mock.return_value = {10: "fish", 20: "chips"}
|
||||
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
|
||||
obj.mock.assert_called_once_with([10, 20], 2)
|
||||
obj.mock.assert_called_once_with((10, 20), 2)
|
||||
self.assertEqual(r1, {10: "fish", 20: "chips"})
|
||||
obj.mock.reset_mock()
|
||||
|
||||
|
169
tests/util/test_batching_queue.py
Normal file
169
tests/util/test_batching_queue.py
Normal file
@ -0,0 +1,169 @@
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.util.batching_queue import BatchingQueue
|
||||
|
||||
from tests.server import get_clock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
class BatchingQueueTestCase(TestCase):
|
||||
def setUp(self):
|
||||
self.clock, hs_clock = get_clock()
|
||||
|
||||
self._pending_calls = []
|
||||
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
|
||||
|
||||
async def _process_queue(self, values):
|
||||
d = defer.Deferred()
|
||||
self._pending_calls.append((values, d))
|
||||
return await make_deferred_yieldable(d)
|
||||
|
||||
def test_simple(self):
|
||||
"""Tests the basic case of calling `add_to_queue` once and having
|
||||
`_process_queue` return.
|
||||
"""
|
||||
|
||||
self.assertFalse(self._pending_calls)
|
||||
|
||||
queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))
|
||||
|
||||
# The queue should wait a reactor tick before calling the processing
|
||||
# function.
|
||||
self.assertFalse(self._pending_calls)
|
||||
self.assertFalse(queue_d.called)
|
||||
|
||||
# We should see a call to `_process_queue` after a reactor tick.
|
||||
self.clock.pump([0])
|
||||
|
||||
self.assertEqual(len(self._pending_calls), 1)
|
||||
self.assertEqual(self._pending_calls[0][0], ["foo"])
|
||||
self.assertFalse(queue_d.called)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back.
|
||||
self._pending_calls.pop()[1].callback("bar")
|
||||
|
||||
self.assertEqual(self.successResultOf(queue_d), "bar")
|
||||
|
||||
def test_batching(self):
|
||||
"""Test that multiple calls at the same time get batched up into one
|
||||
call to `_process_queue`.
|
||||
"""
|
||||
|
||||
self.assertFalse(self._pending_calls)
|
||||
|
||||
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
||||
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
||||
|
||||
self.clock.pump([0])
|
||||
|
||||
# We should see only *one* call to `_process_queue`
|
||||
self.assertEqual(len(self._pending_calls), 1)
|
||||
self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
|
||||
self.assertFalse(queue_d1.called)
|
||||
self.assertFalse(queue_d2.called)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to both.
|
||||
self._pending_calls.pop()[1].callback("bar")
|
||||
|
||||
self.assertEqual(self.successResultOf(queue_d1), "bar")
|
||||
self.assertEqual(self.successResultOf(queue_d2), "bar")
|
||||
|
||||
def test_queuing(self):
|
||||
"""Test that we queue up requests while a `_process_queue` is being
|
||||
called.
|
||||
"""
|
||||
|
||||
self.assertFalse(self._pending_calls)
|
||||
|
||||
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
||||
self.clock.pump([0])
|
||||
|
||||
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
||||
|
||||
# We should see only *one* call to `_process_queue`
|
||||
self.assertEqual(len(self._pending_calls), 1)
|
||||
self.assertEqual(self._pending_calls[0][0], ["foo1"])
|
||||
self.assertFalse(queue_d1.called)
|
||||
self.assertFalse(queue_d2.called)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to the
|
||||
# first.
|
||||
self._pending_calls.pop()[1].callback("bar1")
|
||||
|
||||
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
||||
self.assertFalse(queue_d2.called)
|
||||
|
||||
# We should now see a second call to `_process_queue`
|
||||
self.clock.pump([0])
|
||||
self.assertEqual(len(self._pending_calls), 1)
|
||||
self.assertEqual(self._pending_calls[0][0], ["foo2"])
|
||||
self.assertFalse(queue_d2.called)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to the
|
||||
# second.
|
||||
self._pending_calls.pop()[1].callback("bar2")
|
||||
|
||||
self.assertEqual(self.successResultOf(queue_d2), "bar2")
|
||||
|
||||
def test_different_keys(self):
|
||||
"""Test that calls to different keys get processed in parallel."""
|
||||
|
||||
self.assertFalse(self._pending_calls)
|
||||
|
||||
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1))
|
||||
self.clock.pump([0])
|
||||
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2))
|
||||
self.clock.pump([0])
|
||||
|
||||
# We queue up another item with key=2 to check that we will keep taking
|
||||
# things off the queue.
|
||||
queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3", key=2))
|
||||
|
||||
# We should see two calls to `_process_queue`
|
||||
self.assertEqual(len(self._pending_calls), 2)
|
||||
self.assertEqual(self._pending_calls[0][0], ["foo1"])
|
||||
self.assertEqual(self._pending_calls[1][0], ["foo2"])
|
||||
self.assertFalse(queue_d1.called)
|
||||
self.assertFalse(queue_d2.called)
|
||||
self.assertFalse(queue_d3.called)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to the
|
||||
# first.
|
||||
self._pending_calls.pop(0)[1].callback("bar1")
|
||||
|
||||
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
||||
self.assertFalse(queue_d2.called)
|
||||
self.assertFalse(queue_d3.called)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to the
|
||||
# second.
|
||||
self._pending_calls.pop()[1].callback("bar2")
|
||||
|
||||
self.assertEqual(self.successResultOf(queue_d2), "bar2")
|
||||
self.assertFalse(queue_d3.called)
|
||||
|
||||
# We should now see a call `_pending_calls` for `foo3`
|
||||
self.clock.pump([0])
|
||||
self.assertEqual(len(self._pending_calls), 1)
|
||||
self.assertEqual(self._pending_calls[0][0], ["foo3"])
|
||||
self.assertFalse(queue_d3.called)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to the
|
||||
# third deferred.
|
||||
self._pending_calls.pop()[1].callback("bar4")
|
||||
|
||||
self.assertEqual(self.successResultOf(queue_d3), "bar4")
|
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, List
|
||||
from typing import Dict, Iterable, List, Sequence
|
||||
|
||||
from synapse.util.iterutils import chunk_seq, sorted_topologically
|
||||
|
||||
@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase):
|
||||
)
|
||||
|
||||
def test_empty_input(self):
|
||||
parts = chunk_seq([], 5)
|
||||
parts = chunk_seq([], 5) # type: Iterable[Sequence]
|
||||
|
||||
self.assertEqual(
|
||||
list(parts),
|
||||
|
@ -59,7 +59,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEquals(cache.pop("key"), None)
|
||||
|
||||
def test_del_multi(self):
|
||||
cache = LruCache(4, keylen=2, cache_type=TreeCache)
|
||||
cache = LruCache(4, cache_type=TreeCache)
|
||||
cache[("animal", "cat")] = "mew"
|
||||
cache[("animal", "dog")] = "woof"
|
||||
cache[("vehicles", "car")] = "vroom"
|
||||
@ -165,7 +165,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
||||
m2 = Mock()
|
||||
m3 = Mock()
|
||||
m4 = Mock()
|
||||
cache = LruCache(4, keylen=2, cache_type=TreeCache)
|
||||
cache = LruCache(4, cache_type=TreeCache)
|
||||
|
||||
cache.set(("a", "1"), "value", callbacks=[m1])
|
||||
cache.set(("a", "2"), "value", callbacks=[m2])
|
||||
|
@ -51,10 +51,12 @@ class RetryLimiterTestCase(HomeserverTestCase):
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
self.pump()
|
||||
|
||||
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
|
||||
self.assertEqual(new_timings["failure_ts"], failure_ts)
|
||||
self.assertEqual(new_timings["retry_last_ts"], failure_ts)
|
||||
self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
|
||||
self.assertEqual(new_timings.failure_ts, failure_ts)
|
||||
self.assertEqual(new_timings.retry_last_ts, failure_ts)
|
||||
self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
|
||||
|
||||
# now if we try again we should get a failure
|
||||
self.get_failure(
|
||||
@ -77,14 +79,16 @@ class RetryLimiterTestCase(HomeserverTestCase):
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
self.pump()
|
||||
|
||||
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
|
||||
self.assertEqual(new_timings["failure_ts"], failure_ts)
|
||||
self.assertEqual(new_timings["retry_last_ts"], retry_ts)
|
||||
self.assertEqual(new_timings.failure_ts, failure_ts)
|
||||
self.assertEqual(new_timings.retry_last_ts, retry_ts)
|
||||
self.assertGreaterEqual(
|
||||
new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
|
||||
new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
|
||||
)
|
||||
self.assertLessEqual(
|
||||
new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
|
||||
new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
|
||||
)
|
||||
|
||||
#
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from synapse.util.caches.treecache import TreeCache
|
||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||
|
||||
from .. import unittest
|
||||
|
||||
@ -64,12 +64,14 @@ class TreeCacheTestCase(unittest.TestCase):
|
||||
cache[("a", "b")] = "AB"
|
||||
cache[("b", "a")] = "BA"
|
||||
self.assertEquals(cache.get(("a", "a")), "AA")
|
||||
cache.pop(("a",))
|
||||
popped = cache.pop(("a",))
|
||||
self.assertEquals(cache.get(("a", "a")), None)
|
||||
self.assertEquals(cache.get(("a", "b")), None)
|
||||
self.assertEquals(cache.get(("b", "a")), "BA")
|
||||
self.assertEquals(len(cache), 1)
|
||||
|
||||
self.assertEquals({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))
|
||||
|
||||
def test_clear(self):
|
||||
cache = TreeCache()
|
||||
cache[("a",)] = "A"
|
||||
|
10
tox.ini
10
tox.ini
@ -34,7 +34,17 @@ lint_targets =
|
||||
synapse
|
||||
tests
|
||||
scripts
|
||||
# annoyingly, black doesn't find these so we have to list them
|
||||
scripts/export_signing_key
|
||||
scripts/generate_config
|
||||
scripts/generate_log_config
|
||||
scripts/hash_password
|
||||
scripts/register_new_matrix_user
|
||||
scripts/synapse_port_db
|
||||
scripts-dev
|
||||
scripts-dev/build_debian_packages
|
||||
scripts-dev/sign_json
|
||||
scripts-dev/update_database
|
||||
stubs
|
||||
contrib
|
||||
synctl
|
||||
|
Loading…
x
Reference in New Issue
Block a user