Merge remote-tracking branch 'upstream/release-v1.32.0'

This commit is contained in:
Tulir Asokan 2021-04-13 16:44:48 +03:00
commit dbe12b3eb1
207 changed files with 4883 additions and 1550 deletions


@@ -1,16 +1,16 @@
#!/usr/bin/env bash

-# this script is run by buildkite in a plain `xenial` container; it installs the
-# minimal requirements for tox and hands over to the py35-old tox environment.
+# this script is run by buildkite in a plain `bionic` container; it installs the
+# minimal requirements for tox and hands over to the py3-old tox environment.

set -ex

apt-get update
-apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
+apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox

export LANG="C.UTF-8"

# Prevent virtualenv from auto-updating pip to an incompatible version
export VIRTUALENV_NO_DOWNLOAD=1

-exec tox -e py35-old,combine
+exec tox -e py3-old,combine

.github/workflows/tests.yml (new file, 322 lines)

@@ -0,0 +1,322 @@
name: Tests
on:
push:
branches: ["develop", "release-*"]
pull_request:
jobs:
lint:
runs-on: ubuntu-latest
strategy:
matrix:
toxenv:
- "check-sampleconfig"
- "check_codestyle"
- "check_isort"
- "mypy"
- "packaging"
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- run: pip install tox
- run: tox -e ${{ matrix.toxenv }}
lint-crlf:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Check line endings
run: scripts-dev/check_line_terminators.sh
lint-newsfile:
if: ${{ github.base_ref == 'develop' || contains(github.base_ref, 'release-') }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- run: pip install tox
- name: Patch Buildkite-specific test script
run: |
sed -i -e 's/\$BUILDKITE_PULL_REQUEST/${{ github.event.number }}/' \
scripts-dev/check-newsfragment
- run: scripts-dev/check-newsfragment
lint-sdist:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: "3.x"
- run: pip install wheel
- run: python setup.py sdist bdist_wheel
- uses: actions/upload-artifact@v2
with:
name: Python Distributions
path: dist/*
# Dummy step to gate other tests on without repeating the whole list
linting-done:
if: ${{ always() }} # Run this even if prior jobs were skipped
needs: [lint, lint-crlf, lint-newsfile, lint-sdist]
runs-on: ubuntu-latest
steps:
- run: "true"
trial:
if: ${{ !failure() }} # Allow previous steps to be skipped, but not fail
needs: linting-done
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.6", "3.7", "3.8", "3.9"]
database: ["sqlite"]
include:
# Newest Python without optional deps
- python-version: "3.9"
toxenv: "py-noextras,combine"
# Oldest Python with PostgreSQL
- python-version: "3.6"
database: "postgres"
postgres-version: "9.6"
# Newest Python with PostgreSQL
- python-version: "3.9"
database: "postgres"
postgres-version: "13"
steps:
- uses: actions/checkout@v2
- run: sudo apt-get -qq install xmlsec1
- name: Set up PostgreSQL ${{ matrix.postgres-version }}
if: ${{ matrix.postgres-version }}
run: |
docker run -d -p 5432:5432 \
-e POSTGRES_PASSWORD=postgres \
-e POSTGRES_INITDB_ARGS="--lc-collate C --lc-ctype C --encoding UTF8" \
postgres:${{ matrix.postgres-version }}
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- run: pip install tox
- name: Await PostgreSQL
if: ${{ matrix.postgres-version }}
timeout-minutes: 2
run: until pg_isready -h localhost; do sleep 1; done
- run: tox -e py,combine
env:
TRIAL_FLAGS: "--jobs=2"
SYNAPSE_POSTGRES: ${{ matrix.database == 'postgres' || '' }}
SYNAPSE_POSTGRES_HOST: localhost
SYNAPSE_POSTGRES_USER: postgres
SYNAPSE_POSTGRES_PASSWORD: postgres
- name: Dump logs
# Note: Dumps to workflow logs instead of using actions/upload-artifact
# This keeps logs colocated with failing jobs
# It also ignores find's exit code; this is a best effort affair
run: >-
find _trial_temp -name '*.log'
-exec echo "::group::{}" \;
-exec cat {} \;
-exec echo "::endgroup::" \;
|| true
trial-olddeps:
if: ${{ !failure() }} # Allow previous steps to be skipped, but not fail
needs: linting-done
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Test with old deps
uses: docker://ubuntu:bionic # For old python and sqlite
with:
workdir: /github/workspace
entrypoint: .buildkite/scripts/test_old_deps.sh
env:
TRIAL_FLAGS: "--jobs=2"
- name: Dump logs
# Note: Dumps to workflow logs instead of using actions/upload-artifact
# This keeps logs colocated with failing jobs
# It also ignores find's exit code; this is a best effort affair
run: >-
find _trial_temp -name '*.log'
-exec echo "::group::{}" \;
-exec cat {} \;
-exec echo "::endgroup::" \;
|| true
trial-pypy:
# Very slow; only run if the branch name includes 'pypy'
if: ${{ contains(github.ref, 'pypy') && !failure() }}
needs: linting-done
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["pypy-3.6"]
steps:
- uses: actions/checkout@v2
- run: sudo apt-get -qq install xmlsec1 libxml2-dev libxslt-dev
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- run: pip install tox
- run: tox -e py,combine
env:
TRIAL_FLAGS: "--jobs=2"
- name: Dump logs
# Note: Dumps to workflow logs instead of using actions/upload-artifact
# This keeps logs colocated with failing jobs
# It also ignores find's exit code; this is a best effort affair
run: >-
find _trial_temp -name '*.log'
-exec echo "::group::{}" \;
-exec cat {} \;
-exec echo "::endgroup::" \;
|| true
sytest:
if: ${{ !failure() }}
needs: linting-done
runs-on: ubuntu-latest
container:
image: matrixdotorg/sytest-synapse:${{ matrix.sytest-tag }}
volumes:
- ${{ github.workspace }}:/src
env:
BUILDKITE_BRANCH: ${{ github.head_ref }}
POSTGRES: ${{ matrix.postgres && 1 }}
MULTI_POSTGRES: ${{ (matrix.postgres == 'multi-postgres') && 1 }}
WORKERS: ${{ matrix.workers && 1 }}
REDIS: ${{ matrix.redis && 1 }}
BLACKLIST: ${{ matrix.workers && 'synapse-blacklist-with-workers' }}
strategy:
fail-fast: false
matrix:
include:
- sytest-tag: bionic
- sytest-tag: bionic
postgres: postgres
- sytest-tag: testing
postgres: postgres
- sytest-tag: bionic
postgres: multi-postgres
workers: workers
- sytest-tag: buster
postgres: multi-postgres
workers: workers
- sytest-tag: buster
postgres: postgres
workers: workers
redis: redis
steps:
- uses: actions/checkout@v2
- name: Prepare test blacklist
run: cat sytest-blacklist .buildkite/worker-blacklist > synapse-blacklist-with-workers
- name: Run SyTest
run: /bootstrap.sh synapse
working-directory: /src
- name: Dump results.tap
if: ${{ always() }}
run: cat /logs/results.tap
- name: Upload SyTest logs
uses: actions/upload-artifact@v2
if: ${{ always() }}
with:
name: Sytest Logs - ${{ job.status }} - (${{ join(matrix.*, ', ') }})
path: |
/logs/results.tap
/logs/**/*.log*
portdb:
if: ${{ !failure() }} # Allow previous steps to be skipped, but not fail
needs: linting-done
runs-on: ubuntu-latest
strategy:
matrix:
include:
- python-version: "3.6"
postgres-version: "9.6"
- python-version: "3.9"
postgres-version: "13"
services:
postgres:
image: postgres:${{ matrix.postgres-version }}
ports:
- 5432:5432
env:
POSTGRES_PASSWORD: "postgres"
POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- uses: actions/checkout@v2
- run: sudo apt-get -qq install xmlsec1
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Patch Buildkite-specific test scripts
run: |
sed -i -e 's/host="postgres"/host="localhost"/' .buildkite/scripts/create_postgres_db.py
sed -i -e 's/host: postgres/host: localhost/' .buildkite/postgres-config.yaml
sed -i -e 's|/src/||' .buildkite/{sqlite,postgres}-config.yaml
sed -i -e 's/\$TOP/\$GITHUB_WORKSPACE/' .coveragerc
- run: .buildkite/scripts/test_synapse_port_db.sh
complement:
if: ${{ !failure() }}
needs: linting-done
runs-on: ubuntu-latest
container:
# https://github.com/matrix-org/complement/blob/master/dockerfiles/ComplementCIBuildkite.Dockerfile
image: matrixdotorg/complement:latest
env:
CI: true
ports:
- 8448:8448
volumes:
- /var/run/docker.sock:/var/run/docker.sock
steps:
- name: Run actions/checkout@v2 for synapse
uses: actions/checkout@v2
with:
path: synapse
- name: Run actions/checkout@v2 for complement
uses: actions/checkout@v2
with:
repository: "matrix-org/complement"
path: complement
# Build initial Synapse image
- run: docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
working-directory: synapse
# Build a ready-to-run Synapse image based on the initial image above.
# This new image includes a config file, keys for signing and TLS, and
# other settings to make it suitable for testing under Complement.
- run: docker build -t complement-synapse -f Synapse.Dockerfile .
working-directory: complement/dockerfiles
# Run Complement
- run: go test -v -tags synapse_blacklist ./tests
env:
COMPLEMENT_BASE_IMAGE: complement-synapse:latest
working-directory: complement


@@ -1,3 +1,75 @@
Synapse 1.32.0rc1 (2021-04-13)
==============================
**Note:** This release requires Python 3.6+ and Postgres 9.6+ or SQLite 3.22+.
This release removes the deprecated `GET /_synapse/admin/v1/users/<user_id>` admin API. Please use the [v2 API](https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/user_admin_api.rst#query-user-account) instead, which has improved capabilities.
This release requires Application Services to use type `m.login.application_service` when registering users via the `/_matrix/client/r0/register` endpoint to comply with the spec. Please ensure your Application Services are up to date.
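For illustration, an Application Service registration request would now carry the login type in its body, something like this (illustrative values):

```json
{
  "type": "m.login.application_service",
  "username": "_example_bridge_user"
}
```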
Features
--------
- Add a Synapse module for routing presence updates between users. ([\#9491](https://github.com/matrix-org/synapse/issues/9491))
- Add an admin API to manage ratelimit for a specific user. ([\#9648](https://github.com/matrix-org/synapse/issues/9648))
- Include request information in structured logging output. ([\#9654](https://github.com/matrix-org/synapse/issues/9654))
- Add `order_by` to the admin API `GET /_synapse/admin/v2/users`. Contributed by @dklimpel. ([\#9691](https://github.com/matrix-org/synapse/issues/9691))
- Replace the `room_invite_state_types` configuration setting with `room_prejoin_state`. ([\#9700](https://github.com/matrix-org/synapse/issues/9700))
- Add experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. ([\#9717](https://github.com/matrix-org/synapse/issues/9717), [\#9735](https://github.com/matrix-org/synapse/issues/9735))
- Update experimental support for Spaces: include `m.room.create` in the room state sent with room-invites. ([\#9710](https://github.com/matrix-org/synapse/issues/9710))
- Synapse now requires Python 3.6 or later. It also requires Postgres 9.6 or later or SQLite 3.22 or later. ([\#9766](https://github.com/matrix-org/synapse/issues/9766))
Bugfixes
--------
- Prevent `synapse_forward_extremities` and `synapse_excess_extremity_events` Prometheus metrics from initially reporting zero-values after startup. ([\#8926](https://github.com/matrix-org/synapse/issues/8926))
- Fix recently added ratelimits to correctly honour the application service `rate_limited` flag. ([\#9711](https://github.com/matrix-org/synapse/issues/9711))
- Fix longstanding bug which caused `duplicate key value violates unique constraint "remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"` errors. ([\#9725](https://github.com/matrix-org/synapse/issues/9725))
- Fix bug where sharded federation senders could get stuck repeatedly querying the DB in a loop, using lots of CPU. ([\#9770](https://github.com/matrix-org/synapse/issues/9770))
- Fix duplicate logging of exceptions thrown during federation transaction processing. ([\#9780](https://github.com/matrix-org/synapse/issues/9780))
Updates to the Docker image
---------------------------
- Move opencontainers labels to the final Docker image such that users can inspect them. ([\#9765](https://github.com/matrix-org/synapse/issues/9765))
Improved Documentation
----------------------
- Make the `allowed_local_3pids` regex example in the sample config stricter. ([\#9719](https://github.com/matrix-org/synapse/issues/9719))
Deprecations and Removals
-------------------------
- Remove old admin API `GET /_synapse/admin/v1/users/<user_id>`. ([\#9401](https://github.com/matrix-org/synapse/issues/9401))
- Make `/_matrix/client/r0/register` expect a type of `m.login.application_service` when an Application Service registers a user, to align with [the relevant spec](https://spec.matrix.org/unstable/application-service-api/#server-admin-style-permissions). ([\#9548](https://github.com/matrix-org/synapse/issues/9548))
Internal Changes
----------------
- Replace deprecated `imp` module with successor `importlib`. Contributed by Cristina Muñoz. ([\#9718](https://github.com/matrix-org/synapse/issues/9718))
- Experiment with GitHub Actions for CI. ([\#9661](https://github.com/matrix-org/synapse/issues/9661))
- Introduce flake8-bugbear to the test suite and fix some of its lint violations. ([\#9682](https://github.com/matrix-org/synapse/issues/9682))
- Update `scripts-dev/complement.sh` to use a local checkout of Complement, allow running a subset of tests and have it use Synapse's Complement test blacklist. ([\#9685](https://github.com/matrix-org/synapse/issues/9685))
- Improve Jaeger tracing for `to_device` messages. ([\#9686](https://github.com/matrix-org/synapse/issues/9686))
- Add release helper script for automating part of the Synapse release process. ([\#9713](https://github.com/matrix-org/synapse/issues/9713))
- Add type hints to expiring cache. ([\#9730](https://github.com/matrix-org/synapse/issues/9730))
- Convert various testcases to `HomeserverTestCase`. ([\#9736](https://github.com/matrix-org/synapse/issues/9736))
- Start linting mypy with `no_implicit_optional`. ([\#9742](https://github.com/matrix-org/synapse/issues/9742))
- Add missing type hints to federation handler and server. ([\#9743](https://github.com/matrix-org/synapse/issues/9743))
- Check that a `ConfigError` is raised, rather than simply `Exception`, when appropriate in homeserver config file generation tests. ([\#9753](https://github.com/matrix-org/synapse/issues/9753))
- Fix incompatibility with `tox` 2.5. ([\#9769](https://github.com/matrix-org/synapse/issues/9769))
- Enable Complement tests for [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946): Spaces Summary API. ([\#9771](https://github.com/matrix-org/synapse/issues/9771))
- Use mock from the standard library instead of a separate package. ([\#9772](https://github.com/matrix-org/synapse/issues/9772))
- Update Black configuration to target Python 3.6. ([\#9781](https://github.com/matrix-org/synapse/issues/9781))
- Add option to skip unit tests when building Debian packages. ([\#9793](https://github.com/matrix-org/synapse/issues/9793))

Synapse 1.31.0 (2021-04-06)
===========================


@@ -393,7 +393,12 @@ massive excess of outgoing federation requests (see `discussion
indicate that your server is also issuing far more outgoing federation
requests than can be accounted for by your users' activity, this is a
likely cause. The misbehavior can be worked around by setting
-``use_presence: false`` in the Synapse config file.
+the following in the Synapse config file:
+
+.. code-block:: yaml
+
+    presence:
+        enabled: false

People can't accept room invitations from me
--------------------------------------------


@@ -85,6 +85,19 @@ for example:

    wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
    dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.32.0
====================
Removal of old List Accounts Admin API
--------------------------------------
The deprecated v1 "list accounts" admin API (``GET /_synapse/admin/v1/users/<user_id>``) has been removed in this version.
The `v2 list accounts API <https://github.com/matrix-org/synapse/blob/master/docs/admin_api/user_admin_api.rst#list-accounts>`_
has been available since Synapse 1.7.0 (2019-12-13), and is accessible under ``GET /_synapse/admin/v2/users``.
The deprecation of the old endpoint was announced with Synapse 1.28.0 (released on 2021-02-25).
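For example, the v2 endpoint can be queried as follows (an illustrative request; substitute your homeserver's address and an admin ``access_token``)::

    curl --header "Authorization: Bearer <access_token>" \
        "http://localhost:8008/_synapse/admin/v2/users?from=0&limit=10"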

Upgrading to v1.29.0
====================


@@ -24,6 +24,7 @@ import sys
import time
import urllib
from http import TwistedHttpClient
+from typing import Optional

import nacl.encoding
import nacl.signing
@@ -718,7 +719,7 @@ class SynapseCmd(cmd.Cmd):
        method,
        path,
        data=None,
-        query_params={"access_token": None},
+        query_params: Optional[dict] = None,
        alt_text=None,
    ):
        """Runs an HTTP request and pretty prints the output.
@@ -729,6 +730,8 @@
            data: Raw JSON data if any
            query_params: dict of query parameters to add to the url
        """
+        query_params = query_params or {"access_token": None}
+
        url = self._url() + path
        if "access_token" in query_params:
            query_params["access_token"] = self._tok()


@@ -16,6 +16,7 @@
import json
import urllib
from pprint import pformat
+from typing import Optional

from twisted.internet import defer, reactor
from twisted.web.client import Agent, readBody
@@ -85,8 +86,9 @@ class TwistedHttpClient(HttpClient):
        body = yield readBody(response)
        defer.returnValue(json.loads(body))

-    def _create_put_request(self, url, json_data, headers_dict={}):
+    def _create_put_request(self, url, json_data, headers_dict: Optional[dict] = None):
        """Wrapper of _create_request to issue a PUT request"""
+        headers_dict = headers_dict or {}

        if "Content-Type" not in headers_dict:
            raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
@@ -95,14 +97,22 @@
            "PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict
        )

-    def _create_get_request(self, url, headers_dict={}):
+    def _create_get_request(self, url, headers_dict: Optional[dict] = None):
        """Wrapper of _create_request to issue a GET request"""
-        return self._create_request("GET", url, headers_dict=headers_dict)
+        return self._create_request("GET", url, headers_dict=headers_dict or {})

    @defer.inlineCallbacks
    def do_request(
-        self, method, url, data=None, qparams=None, jsonreq=True, headers={}
+        self,
+        method,
+        url,
+        data=None,
+        qparams=None,
+        jsonreq=True,
+        headers: Optional[dict] = None,
    ):
+        headers = headers or {}
+
        if qparams:
            url = "%s?%s" % (url, urllib.urlencode(qparams, True))
@@ -123,8 +133,12 @@
        defer.returnValue(json.loads(body))

    @defer.inlineCallbacks
-    def _create_request(self, method, url, producer=None, headers_dict={}):
+    def _create_request(
+        self, method, url, producer=None, headers_dict: Optional[dict] = None
+    ):
        """Creates and sends a request to the given url"""
+        headers_dict = headers_dict or {}
+
        headers_dict["User-Agent"] = ["Synapse Cmd Client"]

        retries_left = 5


@@ -50,7 +50,13 @@ PACKAGE_BUILD_DIR="debian/matrix-synapse-py3"
VIRTUALENV_DIR="${PACKAGE_BUILD_DIR}${DH_VIRTUALENV_INSTALL_ROOT}/matrix-synapse"
TARGET_PYTHON="${VIRTUALENV_DIR}/bin/python"

-# we copy the tests to a temporary directory so that we can put them on the
+case "$DEB_BUILD_OPTIONS" in
+    *nocheck*)
+        # Skip running tests if "nocheck" present in $DEB_BUILD_OPTIONS
+        ;;
+    *)
+        # Copy tests to a temporary directory so that we can put them on the
# PYTHONPATH without putting the uninstalled synapse on the pythonpath.
tmpdir=`mktemp -d`
trap "rm -r $tmpdir" EXIT
@@ -60,6 +66,9 @@ cp -r tests "$tmpdir"
PYTHONPATH="$tmpdir" \
    "${TARGET_PYTHON}" -m twisted.trial --reporter=text -j2 tests
+        ;;
+esac

# build the config file
"${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_config" \
    --config-dir="/etc/matrix-synapse" \

debian/changelog (6 lines changed)

@@ -1,3 +1,9 @@
matrix-synapse-py3 (1.31.0+nmu1) UNRELEASED; urgency=medium

  * Skip tests when DEB_BUILD_OPTIONS contains "nocheck".

 -- Dan Callahan <danc@element.io>  Mon, 12 Apr 2021 13:07:36 +0000

matrix-synapse-py3 (1.31.0) stable; urgency=medium

  * New synapse release 1.31.0.


@@ -18,11 +18,6 @@ ARG PYTHON_VERSION=3.8
###
FROM docker.io/python:${PYTHON_VERSION}-slim as builder

-LABEL org.opencontainers.image.url='https://matrix.org/docs/projects/server/synapse'
-LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/synapse/blob/master/docker/README.md'
-LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git'
-LABEL org.opencontainers.image.licenses='Apache-2.0'
-
# install the OS build deps
RUN apt-get update && apt-get install -y \
    build-essential \
@@ -66,6 +61,11 @@ RUN pip install --prefix="/install" --no-deps --no-warn-script-location /synapse

FROM docker.io/python:${PYTHON_VERSION}-slim

+LABEL org.opencontainers.image.url='https://matrix.org/docs/projects/server/synapse'
+LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/synapse/blob/master/docker/README.md'
+LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git'
+LABEL org.opencontainers.image.licenses='Apache-2.0'
+
RUN apt-get update && apt-get install -y \
    curl \
    gosu \


@@ -173,18 +173,10 @@ report_stats: False

## API Configuration ##

-room_invite_state_types:
-    - "m.room.join_rules"
-    - "m.room.canonical_alias"
-    - "m.room.avatar"
-    - "m.room.name"
-
{% if SYNAPSE_APPSERVICES %}
app_service_config_files:
{% for appservice in SYNAPSE_APPSERVICES %}  - "{{ appservice }}"
{% endfor %}
-{% else %}
-app_service_config_files: []
{% endif %}

macaroon_secret_key: "{{ SYNAPSE_MACAROON_SECRET_KEY }}"


@@ -111,35 +111,16 @@ List Accounts
=============

This API returns all local user accounts.
+By default, the response is ordered by ascending user ID.

-The api is::
+The API is::

    GET /_synapse/admin/v2/users?from=0&limit=10&guests=false

To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.

-The parameter ``from`` is optional but used for pagination, denoting the
-offset in the returned results. This should be treated as an opaque value and
-not explicitly set to anything other than the return value of ``next_token``
-from a previous call.
-
-The parameter ``limit`` is optional but is used for pagination, denoting the
-maximum number of items to return in this call. Defaults to ``100``.
-
-The parameter ``user_id`` is optional and filters to only return users with user IDs
-that contain this value. This parameter is ignored when using the ``name`` parameter.
-
-The parameter ``name`` is optional and filters to only return users with user ID localparts
-**or** displaynames that contain this value.
-
-The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
-Defaults to ``true`` to include guest users.
-
-The parameter ``deactivated`` is optional and if ``true`` will **include** deactivated users.
-Defaults to ``false`` to exclude deactivated users.
-
-A JSON body is returned with the following shape:
+A response body like the following is returned:

.. code:: json
@@ -175,6 +156,66 @@ with ``from`` set to the value of ``next_token``. This will return a new page.

If the endpoint does not return a ``next_token`` then there are no more users
to paginate through.
**Parameters**
The following parameters should be set in the URL:
- ``user_id`` - Is optional and filters to only return users with user IDs
that contain this value. This parameter is ignored when using the ``name`` parameter.
- ``name`` - Is optional and filters to only return users with user ID localparts
**or** displaynames that contain this value.
- ``guests`` - string representing a bool - Is optional and if ``false`` will **exclude** guest users.
Defaults to ``true`` to include guest users.
- ``deactivated`` - string representing a bool - Is optional and if ``true`` will **include** deactivated users.
Defaults to ``false`` to exclude deactivated users.
- ``limit`` - string representing a positive integer - Is optional but is used for pagination,
denoting the maximum number of items to return in this call. Defaults to ``100``.
- ``from`` - string representing a positive integer - Is optional but used for pagination,
denoting the offset in the returned results. This should be treated as an opaque value and
not explicitly set to anything other than the return value of ``next_token`` from a previous call.
Defaults to ``0``.
- ``order_by`` - The method by which to sort the returned list of users.
If the ordered field has duplicates, the second order is always by ascending ``name``,
which guarantees a stable ordering. Valid values are:
- ``name`` - Users are ordered alphabetically by ``name``. This is the default.
- ``is_guest`` - Users are ordered by ``is_guest`` status.
- ``admin`` - Users are ordered by ``admin`` status.
- ``user_type`` - Users are ordered alphabetically by ``user_type``.
- ``deactivated`` - Users are ordered by ``deactivated`` status.
- ``shadow_banned`` - Users are ordered by ``shadow_banned`` status.
- ``displayname`` - Users are ordered alphabetically by ``displayname``.
- ``avatar_url`` - Users are ordered alphabetically by avatar URL.
- ``dir`` - Direction of user order. Either ``f`` for forwards or ``b`` for backwards.
  Setting this value to ``b`` will reverse the above sort order. Defaults to ``f``.

Caution: the database only has indexes on the columns ``name`` and ``created_ts``.
This means that if a different sort order is used (``is_guest``, ``admin``,
``user_type``, ``deactivated``, ``shadow_banned``, ``avatar_url`` or ``displayname``),
this can cause a large load on the database, especially for large environments.
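For instance, to fetch the first ten users sorted by display name in reverse
order (an illustrative request; note the caution above about unindexed sort
columns)::

    GET /_synapse/admin/v2/users?order_by=displayname&dir=b&limit=10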
**Response**
The following fields are returned in the JSON response body:
- ``users`` - An array of objects, each containing information about a user.
  User objects contain the following fields:
- ``name`` - string - Fully-qualified user ID (ex. ``@user:server.com``).
- ``is_guest`` - bool - Whether the user is a guest account.
- ``admin`` - bool - Whether the user is a server administrator.
- ``user_type`` - string - Type of the user. Normal users are type ``None``.
  This allows user-type-specific behaviour. There are also types ``support`` and ``bot``.
- ``deactivated`` - bool - Whether the user has been marked as deactivated.
- ``shadow_banned`` - bool - Whether the user has been marked as shadow-banned.
- ``displayname`` - string - The user's display name if they have set one.
- ``avatar_url`` - string - The user's avatar URL if they have set one.
- ``next_token`` - string representing a positive integer - Indication for pagination. See above.
- ``total`` - integer - Total number of users.

Query current sessions for a user
=================================
@@ -823,3 +864,118 @@ The following parameters should be set in the URL:

- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
  be local.
Override ratelimiting for users
===============================
This API allows you to override or disable ratelimiting for a specific user.
There are specific APIs to set, get and delete a ratelimit.
Get status of ratelimit
-----------------------
The API is::
GET /_synapse/admin/v1/users/<user_id>/override_ratelimit
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
A response body like the following is returned:
.. code:: json
{
"messages_per_second": 0,
"burst_count": 0
}
**Parameters**
The following parameters should be set in the URL:
- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
be local.
**Response**
The following fields are returned in the JSON response body:
- ``messages_per_second`` - integer - The number of actions that can
  be performed in a second. ``0`` means that ratelimiting is disabled for this user.
- ``burst_count`` - integer - How many actions can be performed before
  being limited.
If **no** custom ratelimit is set, an empty JSON dict is returned.
.. code:: json
{}
Set ratelimit
-------------
The API is::
POST /_synapse/admin/v1/users/<user_id>/override_ratelimit
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
A response body like the following is returned:
.. code:: json
{
"messages_per_second": 0,
"burst_count": 0
}
**Parameters**
The following parameters should be set in the URL:
- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
be local.
Body parameters:
- ``messages_per_second`` - positive integer, optional. The number of actions that can
be performed in a second. Defaults to ``0``.
- ``burst_count`` - positive integer, optional. How many actions can be performed
  before being limited. Defaults to ``0``.

To disable ratelimiting for a user, set both values to ``0``.
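An illustrative request body, allowing a user ten actions per second with
bursts of up to twenty:

.. code:: json

    {
        "messages_per_second": 10,
        "burst_count": 20
    }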
**Response**
The following fields are returned in the JSON response body:
- ``messages_per_second`` - integer - The number of actions that can
be performed in a second.
- ``burst_count`` - integer - How many actions can be performed before
  being limited.
Delete ratelimit
----------------
The API is::
DELETE /_synapse/admin/v1/users/<user_id>/override_ratelimit
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
An empty JSON dict is returned.
.. code:: json
{}
**Parameters**
The following parameters should be set in the URL:
- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
be local.


@@ -128,6 +128,9 @@ Some guidelines follow:
  will be if no sub-options are enabled).
- Lines should be wrapped at 80 characters.
- Use two-space indents.
+- `true` and `false` are spelt thus (as opposed to `True`, etc.)
+- Use single quotes (`'`) rather than double-quotes (`"`) or backticks
+  (`` ` ``) to refer to configuration options.

Example:


@@ -0,0 +1,235 @@
# Presence Router Module
Synapse supports configuring a module that can specify additional users
(local or remote) that should receive certain presence updates from local
users.
Note that routing presence via Application Service transactions is not
currently supported.
The presence routing module is implemented as a Python class, which will
be imported by the running Synapse.
## Python Presence Router Class
The Python class is instantiated with two objects:
* A configuration object of some type (see below).
* An instance of `synapse.module_api.ModuleApi`.
It then implements methods related to presence routing.
Note that one method of `ModuleApi` that may be useful is:
```python
async def ModuleApi.send_local_online_presence_to(users: Iterable[str]) -> None
```
which can be given a list of local or remote MXIDs to broadcast known, online user
presence to (for those users that the receiving user is considered interested in).
It does not include state for users who are currently offline, and it can only be
called on workers that support sending federation.
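A minimal sketch of calling it from inside a module method (the user IDs are
illustrative, and `module_api` stands for the `ModuleApi` instance the module
was constructed with):

```python
# Ask Synapse to broadcast current online presence of local users to these
# (local or remote) users, where they are entitled to see it.
async def resend_presence(module_api) -> None:
    await module_api.send_local_online_presence_to(
        ["@alice:example.org", "@bob:example.com"]
    )
```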
### Module structure
Below is a list of possible methods that can be implemented, and whether they are
required.
#### `parse_config`
```python
def parse_config(config_dict: dict) -> Any
```
**Required.** A static method that is passed a dictionary of config options, and
should return a validated config object. This method is described further in
[Configuration](#configuration).
#### `get_users_for_states`
```python
async def get_users_for_states(
self,
state_updates: Iterable[UserPresenceState],
) -> Dict[str, Set[UserPresenceState]]:
```
**Required.** An asynchronous method that is passed an iterable of user presence
state. This method can determine whether a given presence update should be sent to certain
users. It does this by returning a dictionary with keys representing local or remote
Matrix User IDs, and values being a python set
of `synapse.handlers.presence.UserPresenceState` instances.
Synapse will then attempt to send the specified presence updates to each user when
possible.
#### `get_interested_users`
```python
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]
```
**Required.** An asynchronous method that is passed a single Matrix User ID. This
method is expected to return the users that the passed in user may be interested in the
presence of. Returned users may be local or remote. The presence routed as a result of
what this method returns is sent in addition to the updates already sent between users
that share a room together. Presence updates are deduplicated.
This method should return a python set of Matrix User IDs, or the object
`synapse.events.presence_router.PresenceRouter.ALL_USERS` to indicate that the passed
user should receive presence information for *all* known users.
For clarity, if the user `@alice:example.org` is passed to this method, and the Set
`{"@bob:example.com", "@charlie:somewhere.org"}` is returned, this signifies that Alice
should receive presence updates sent by Bob and Charlie, regardless of whether these
users share a room.
### Example
Below is an example implementation of a presence router class.
```python
from typing import Dict, Iterable, List, Set, Union
from synapse.events.presence_router import PresenceRouter
from synapse.handlers.presence import UserPresenceState
from synapse.module_api import ModuleApi
class PresenceRouterConfig:
def __init__(self):
# Config options with their defaults
# A list of users to always send all user presence updates to
self.always_send_to_users = [] # type: List[str]
# A list of users to ignore presence updates for. Does not affect
# shared-room presence relationships
self.blacklisted_users = [] # type: List[str]
class ExamplePresenceRouter:
"""An example implementation of synapse.presence_router.PresenceRouter.
Supports routing all presence to a configured set of users, or a subset
of presence from certain users to members of certain rooms.
Args:
config: A configuration object.
module_api: An instance of Synapse's ModuleApi.
"""
def __init__(self, config: PresenceRouterConfig, module_api: ModuleApi):
self._config = config
self._module_api = module_api
@staticmethod
def parse_config(config_dict: dict) -> PresenceRouterConfig:
"""Parse a configuration dictionary from the homeserver config, do
some validation and return a typed PresenceRouterConfig.
Args:
config_dict: The configuration dictionary.
Returns:
A validated config object.
"""
# Initialise a typed config object
config = PresenceRouterConfig()
always_send_to_users = config_dict.get("always_send_to_users")
blacklisted_users = config_dict.get("blacklisted_users")
# Do some validation of config options... otherwise raise a
# synapse.config.ConfigError.
config.always_send_to_users = always_send_to_users
config.blacklisted_users = blacklisted_users
return config
async def get_users_for_states(
self,
state_updates: Iterable[UserPresenceState],
) -> Dict[str, Set[UserPresenceState]]:
"""Given an iterable of user presence updates, determine where each one
needs to go. Returned results will not affect presence updates that are
sent between users who share a room.
Args:
state_updates: An iterable of user presence state updates.
Returns:
A dictionary of user_id -> set of UserPresenceState that the user should
receive.
"""
destination_users = {}  # type: Dict[str, Set[UserPresenceState]]
# Ignore any updates for blacklisted users
desired_updates = set()
for update in state_updates:
if update.state_key not in self._config.blacklisted_users:
desired_updates.add(update)
# Send all presence updates to specific users
for user_id in self._config.always_send_to_users:
destination_users[user_id] = desired_updates
return destination_users
async def get_interested_users(
self,
user_id: str,
) -> Union[Set[str], PresenceRouter.ALL_USERS]:
"""
Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with.
Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate
that this user should receive all incoming local and remote presence updates.
Note that this method will only be called for local users.
Args:
user_id: A user requesting presence updates.
Returns:
A set of user IDs to return additional presence updates for, or
PresenceRouter.ALL_USERS to return presence updates for all other users.
"""
if user_id in self._config.always_send_to_users:
return PresenceRouter.ALL_USERS
return set()
```
#### A note on `get_users_for_states` and `get_interested_users`
Both of these methods are effectively two different sides of the same coin. The logic
regarding which users should receive updates for other users should be the same
between them.
`get_users_for_states` is called when presence updates come in from either federation
or local users, and is used to either direct local presence to remote users, or to
wake up the sync streams of local users to collect remote presence.
In contrast, `get_interested_users` is used to determine the users that presence should
be fetched for when a local user is syncing. This presence is then retrieved, before
being fed through `get_users_for_states` once again, with only the syncing user's
routing information pulled from the resulting dictionary.
Their routing logic should thus line up, else you may run into unintended behaviour.
## Configuration
Once you've crafted your module and installed it into the same Python environment as
Synapse, amend your homeserver config file with the following.
```yaml
presence:
routing_module:
module: my_module.ExamplePresenceRouter
config:
      # Any configuration options for your module. The below is an example
      # of setting options for ExamplePresenceRouter.
always_send_to_users: ["@presence_gobbler:example.org"]
blacklisted_users:
- "@alice:example.com"
- "@bob:example.com"
...
```
The contents of `config` will be passed as a Python dictionary to the static
`parse_config` method of your class. The object returned by this method will
then be passed to the `__init__` method of your module as `config`.
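To make that hand-off concrete, the loading sequence is roughly equivalent to
the following sketch (not Synapse's actual loader code; `module_api` stands in
for the `ModuleApi` instance):

```python
# 1. Synapse passes the `config` dict from the homeserver YAML to parse_config.
config = ExamplePresenceRouter.parse_config(
    {"always_send_to_users": ["@presence_gobbler:example.org"]}
)
# 2. The validated config object is handed to the module's constructor.
router = ExamplePresenceRouter(config=config, module_api=module_api)
```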


@@ -82,9 +82,28 @@ pid_file: DATADIR/homeserver.pid
#
#soft_file_limit: 0

-# Set to false to disable presence tracking on this homeserver.
+# Presence tracking allows users to see the state (e.g online/offline)
+# of other local and remote users.
#
-#use_presence: false
+presence:
+  # Uncomment to disable presence tracking on this homeserver. This option
+  # replaces the previous top-level 'use_presence' option.
+  #
+  #enabled: false
+
+  # Presence routers are third-party modules that can specify additional logic
+  # to where presence updates from users are routed.
+  #
+  presence_router:
+    # The custom module's class. Uncomment to use a custom presence router module.
+    #
+    #module: "my_custom_router.PresenceRouter"
+
+    # Configuration options of the custom module. Refer to your module's
+    # documentation for available options.
+    #
+    #config:
+    #  example_option: 'something'

# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
@@ -1246,9 +1265,9 @@ account_validity:
#
#allowed_local_3pids:
#  - medium: email
-#    pattern: '.*@matrix\.org'
+#    pattern: '^[^@]+@matrix\.org$'
#  - medium: email
-#    pattern: '.*@vector\.im'
+#    pattern: '^[^@]+@vector\.im$'
#  - medium: msisdn
#    pattern: '\+44'
@@ -1451,14 +1470,31 @@ metrics_flags:

## API Configuration ##

-# A list of event types that will be included in the room_invite_state
+# Controls for the state that is shared with users who receive an invite
+# to a room
#
-#room_invite_state_types:
-#  - "m.room.join_rules"
-#  - "m.room.canonical_alias"
-#  - "m.room.avatar"
-#  - "m.room.encryption"
-#  - "m.room.name"
+room_prejoin_state:
+   # By default, the following state event types are shared with users who
+   # receive invites to the room:
+   #
+   # - m.room.join_rules
+   # - m.room.canonical_alias
+   # - m.room.avatar
+   # - m.room.encryption
+   # - m.room.name
+   #
+   # Uncomment the following to disable these defaults (so that only the event
+   # types listed in 'additional_event_types' are shared). Defaults to 'false'.
+   #
+   #disable_default_event_types: true
+
+   # Additional state event types to share with users when they are invited
+   # to a room.
+   #
+   # By default, this list is empty (so only the default event types are shared).
+   #
+   #additional_event_types:
+   #  - org.example.custom.event.type

# A list of application service config files to use


@@ -8,6 +8,7 @@ show_traceback = True
mypy_path = stubs
warn_unreachable = True
local_partial_types = True
+no_implicit_optional = True

# To find all folders that pass mypy you run:
#


@@ -35,7 +35,7 @@
showcontent = true

[tool.black]
-target-version = ['py35']
+target-version = ['py36']
exclude = '''
(


@@ -18,11 +18,9 @@ import threading
from concurrent.futures import ThreadPoolExecutor

DISTS = (
-    "debian:stretch",
    "debian:buster",
    "debian:bullseye",
    "debian:sid",
-    "ubuntu:xenial",
    "ubuntu:bionic",
    "ubuntu:focal",
    "ubuntu:groovy",
@@ -43,7 +41,7 @@ class Builder(object):
        self._lock = threading.Lock()
        self._failed = False

-    def run_build(self, dist):
+    def run_build(self, dist, skip_tests=False):
        """Build deb for a single distribution"""

        if self._failed:
@@ -51,13 +49,13 @@
            raise Exception("failed")

        try:
-            self._inner_build(dist)
+            self._inner_build(dist, skip_tests)
        except Exception as e:
            print("build of %s failed: %s" % (dist, e), file=sys.stderr)
            self._failed = True
            raise

-    def _inner_build(self, dist):
+    def _inner_build(self, dist, skip_tests=False):
        projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        os.chdir(projdir)
@@ -101,6 +99,7 @@
                "--volume=" + debsdir + ":/debs",
                "-e", "TARGET_USERID=%i" % (os.getuid(), ),
                "-e", "TARGET_GROUPID=%i" % (os.getgid(), ),
+                "-e", "DEB_BUILD_OPTIONS=%s" % ("nocheck" if skip_tests else ""),
                "dh-venv-builder:" + tag,
            ], stdout=stdout, stderr=subprocess.STDOUT)
@@ -124,7 +123,7 @@
            self.active_containers.remove(c)

-def run_builds(dists, jobs=1):
+def run_builds(dists, jobs=1, skip_tests=False):
    builder = Builder(redirect_stdout=(jobs > 1))

    def sig(signum, _frame):
@@ -133,7 +132,7 @@ def run_builds(dists, jobs=1):
    signal.signal(signal.SIGINT, sig)

    with ThreadPoolExecutor(max_workers=jobs) as e:
-        res = e.map(builder.run_build, dists)
+        res = e.map(lambda dist: builder.run_build(dist, skip_tests), dists)

        # make sure we consume the iterable so that exceptions are raised.
        for r in res:
@@ -148,9 +147,13 @@ if __name__ == '__main__':
        '-j', '--jobs', type=int, default=1,
        help='specify the number of builds to run in parallel',
    )
+    parser.add_argument(
+        '--no-check', action='store_true',
+        help='skip running tests after building',
+    )
    parser.add_argument(
        'dist', nargs='*', default=DISTS,
        help='a list of distributions to build for. Default: %(default)s',
    )
    args = parser.parse_args()

-    run_builds(dists=args.dist, jobs=args.jobs)
+    run_builds(dists=args.dist, jobs=args.jobs, skip_tests=args.no_check)


@@ -1,22 +1,49 @@
-#! /bin/bash -eu
+#!/usr/bin/env bash
# This script is designed for developers who want to test their code
# against Complement.
#
# It makes a Synapse image which represents the current checkout,
-# then downloads Complement and runs it with that image.
+# builds a synapse-complement image on top, then runs tests with it.
+#
+# By default the script will fetch the latest Complement master branch and
+# run tests with that. This can be overridden to use a custom Complement
+# checkout by setting the COMPLEMENT_DIR environment variable to the
+# filepath of a local Complement checkout.
+#
+# A regular expression of test method names can be supplied as the first
+# argument to the script. Complement will then only run those tests. If
+# no regex is supplied, all tests are run. For example;
+#
+#   ./complement.sh "TestOutboundFederation(Profile|Send)"
+
+# Exit if a line returns a non-zero exit code
+set -e

+# Change to the repository root
cd "$(dirname $0)/.."

-# Build the base Synapse image from the local checkout
-docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
-
-# Download Complement
-wget -N https://github.com/matrix-org/complement/archive/master.tar.gz
+# Check for a user-specified Complement checkout
+if [[ -z "$COMPLEMENT_DIR" ]]; then
+  echo "COMPLEMENT_DIR not set. Fetching the latest Complement checkout..."
+  wget -Nq https://github.com/matrix-org/complement/archive/master.tar.gz
  tar -xzf master.tar.gz
-cd complement-master
+  COMPLEMENT_DIR=complement-master
+  echo "Checkout available at 'complement-master'"
+fi

-# Build the Synapse image from Complement, based on the above image we just built
-docker build -t complement-synapse -f dockerfiles/Synapse.Dockerfile ./dockerfiles
+# Build the base Synapse image from the local checkout
+docker build -t matrixdotorg/synapse -f docker/Dockerfile .
+# Build the Synapse monolith image from Complement, based on the above image we just built
+docker build -t complement-synapse -f "$COMPLEMENT_DIR/dockerfiles/Synapse.Dockerfile" "$COMPLEMENT_DIR/dockerfiles"

-# Run the tests on the resulting image!
-COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -count=1 ./tests
+cd "$COMPLEMENT_DIR"
+
+EXTRA_COMPLEMENT_ARGS=""
+if [[ -n "$1" ]]; then
+  # A test name regex has been set, supply it to Complement
+  EXTRA_COMPLEMENT_ARGS+="-run $1 "
+fi
+
+# Run the tests!
+COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -tags synapse_blacklist,msc2946,msc3083 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests

scripts-dev/release.py (new executable file, 244 lines)

@@ -0,0 +1,244 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An interactive script for doing a release. See `run()` below.
"""
import subprocess
import sys
from typing import Optional
import click
import git
from packaging import version
from redbaron import RedBaron
@click.command()
def run():
"""An interactive script to walk through the initial stages of creating a
release, including creating release branch, updating changelog and pushing to
GitHub.
Requires the dev dependencies be installed, which can be done via:
pip install -e .[dev]
"""
# Make sure we're in a git repo.
try:
repo = git.Repo()
except git.InvalidGitRepositoryError:
raise click.ClickException("Not in Synapse repo.")
if repo.is_dirty():
raise click.ClickException("Uncommitted changes exist.")
click.secho("Updating git repo...")
repo.remote().fetch()
# Parse the AST and load the `__version__` node so that we can edit it
# later.
with open("synapse/__init__.py") as f:
red = RedBaron(f.read())
version_node = None
for node in red:
if node.type != "assignment":
continue
if node.target.type != "name":
continue
if node.target.value != "__version__":
continue
version_node = node
break
if not version_node:
print("Failed to find '__version__' definition in synapse/__init__.py")
sys.exit(1)
# Parse the current version.
current_version = version.parse(version_node.value.value.strip('"'))
assert isinstance(current_version, version.Version)
# Figure out what sort of release we're doing and calculate the new version.
rc = click.confirm("RC", default=True)
if current_version.pre:
# If the current version is an RC we don't need to bump any of the
# version numbers (other than the RC number).
base_version = "{}.{}.{}".format(
current_version.major,
current_version.minor,
current_version.micro,
)
if rc:
new_version = "{}.{}.{}rc{}".format(
current_version.major,
current_version.minor,
current_version.micro,
current_version.pre[1] + 1,
)
else:
new_version = base_version
else:
# If this is a new release cycle then we need to know if it's a major
# version bump or a hotfix.
release_type = click.prompt(
"Release type",
type=click.Choice(("major", "hotfix")),
show_choices=True,
default="major",
)
if release_type == "major":
base_version = new_version = "{}.{}.{}".format(
current_version.major,
current_version.minor + 1,
0,
)
if rc:
new_version = "{}.{}.{}rc1".format(
current_version.major,
current_version.minor + 1,
0,
)
else:
base_version = new_version = "{}.{}.{}".format(
current_version.major,
current_version.minor,
current_version.micro + 1,
)
if rc:
new_version = "{}.{}.{}rc1".format(
current_version.major,
current_version.minor,
current_version.micro + 1,
)
# Confirm the calculated version is OK.
if not click.confirm(f"Create new version: {new_version}?", default=True):
click.get_current_context().abort()
# Switch to the release branch.
release_branch_name = f"release-v{base_version}"
release_branch = find_ref(repo, release_branch_name)
if release_branch:
if release_branch.is_remote():
# If the release branch only exists on the remote we check it out
# locally.
repo.git.checkout(release_branch_name)
release_branch = repo.active_branch
else:
# If a branch doesn't exist we create one. We ask which branch it
# should be based off, defaulting to sensible values depending on the
# release type.
if current_version.is_prerelease:
default = release_branch_name
elif release_type == "major":
default = "develop"
else:
default = "master"
branch_name = click.prompt(
"Which branch should the release be based on?", default=default
)
base_branch = find_ref(repo, branch_name)
if not base_branch:
print(f"Could not find base branch {branch_name}!")
click.get_current_context().abort()
# Check out the base branch and ensure it's up to date
repo.head.reference = base_branch
repo.head.reset(index=True, working_tree=True)
if not base_branch.is_remote():
update_branch(repo)
# Create the new release branch
release_branch = repo.create_head(release_branch_name, commit=base_branch)
# Switch to the release branch and ensure it's up to date.
repo.git.checkout(release_branch_name)
update_branch(repo)
# Update the `__version__` variable and write it back to the file.
version_node.value = '"' + new_version + '"'
with open("synapse/__init__.py", "w") as f:
f.write(red.dumps())
# Generate changelogs
subprocess.run("python3 -m towncrier", shell=True)
# Generate debian changelogs if it's not an RC.
if not rc:
subprocess.run(
f'dch -M -v {new_version} "New synapse release {new_version}."', shell=True
)
subprocess.run('dch -M -r -D stable ""', shell=True)
# Show the user the changes and ask if they want to edit the change log.
repo.git.add("-u")
subprocess.run("git diff --cached", shell=True)
if click.confirm("Edit changelog?", default=False):
click.edit(filename="CHANGES.md")
# Commit the changes.
repo.git.add("-u")
repo.git.commit(f"-m {new_version}")
# We give the option to bail here in case the user wants to make sure things
# are OK before pushing.
if not click.confirm("Push branch to github?", default=True):
print("")
print("Run when ready to push:")
print("")
print(f"\tgit push -u {repo.remote().name} {repo.active_branch.name}")
print("")
sys.exit(0)
# Otherwise, push and open the changelog in the browser.
repo.git.push("-u", repo.remote().name, repo.active_branch.name)
click.launch(
f"https://github.com/matrix-org/synapse/blob/{repo.active_branch.name}/CHANGES.md"
)
def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]:
"""Find the branch/ref, looking first locally then in the remote."""
if ref_name in repo.refs:
return repo.refs[ref_name]
elif ref_name in repo.remote().refs:
return repo.remote().refs[ref_name]
else:
return None
def update_branch(repo: git.Repo):
"""Ensure branch is up to date if it has a remote"""
if repo.active_branch.tracking_branch():
repo.git.merge(repo.active_branch.tracking_branch().name)
if __name__ == "__main__":
run()
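The version arithmetic above is simple semver manipulation. A minimal sketch of the new-cycle cases, assuming packaging-style Version objects like the ones the script uses for current_version (the function name is illustrative only):

from packaging import version

def next_version(current: str, release_type: str, rc: bool) -> str:
    v = version.parse(current)
    if release_type == "major":
        # Note the script's "major" choice bumps the *minor* component,
        # matching Synapse's regular release cadence.
        base = f"{v.major}.{v.minor + 1}.0"
    else:  # "hotfix" bumps the micro component
        base = f"{v.major}.{v.minor}.{v.micro + 1}"
    return f"{base}rc1" if rc else base

assert next_version("1.31.0", "major", rc=True) == "1.32.0rc1"
assert next_version("1.31.0", "hotfix", rc=False) == "1.31.1"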

View file

@@ -18,16 +18,15 @@ ignore =
# E203: whitespace before ':' (which is contrary to pep8?) # E203: whitespace before ':' (which is contrary to pep8?)
# E731: do not assign a lambda expression, use a def # E731: do not assign a lambda expression, use a def
# E501: Line too long (black enforces this for us) # E501: Line too long (black enforces this for us)
# B00*: Subsection of the bugbear suite (TODO: add in remaining fixes) # B007: Subsection of the bugbear suite (TODO: add in remaining fixes)
ignore=W503,W504,E203,E731,E501,B006,B007,B008 ignore=W503,W504,E203,E731,E501,B007
[isort] [isort]
line_length = 88 line_length = 88
sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER sections=FUTURE,STDLIB,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER
default_section=THIRDPARTY default_section=THIRDPARTY
known_first_party = synapse known_first_party = synapse
known_tests=tests known_tests=tests
known_compat = mock
known_twisted=twisted,OpenSSL known_twisted=twisted,OpenSSL
multi_line_output=3 multi_line_output=3
include_trailing_comma=true include_trailing_comma=true

View file

@@ -103,6 +103,13 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"flake8", "flake8",
] ]
CONDITIONAL_REQUIREMENTS["dev"] = CONDITIONAL_REQUIREMENTS["lint"] + [
# The following are used by the release script
"click==7.1.2",
"redbaron==0.9.2",
"GitPython==3.1.14",
]
CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.13"] CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.13"]
# Dependencies which are exclusively required by unit test code. This is # Dependencies which are exclusively required by unit test code. This is
@@ -110,7 +117,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.13"]
# Tests assume that all optional dependencies are installed. # Tests assume that all optional dependencies are installed.
# #
# parameterized_class decorator was introduced in parameterized 0.7.0 # parameterized_class decorator was introduced in parameterized 0.7.0
CONDITIONAL_REQUIREMENTS["test"] = ["mock>=2.0", "parameterized>=0.7.0"] CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"]
setup( setup(
name="matrix-synapse", name="matrix-synapse",
@@ -123,13 +130,12 @@ setup(
zip_safe=False, zip_safe=False,
long_description=long_description, long_description=long_description,
long_description_content_type="text/x-rst", long_description_content_type="text/x-rst",
python_requires="~=3.5", python_requires="~=3.6",
classifiers=[ classifiers=[
"Development Status :: 5 - Production/Stable", "Development Status :: 5 - Production/Stable",
"Topic :: Communications :: Chat", "Topic :: Communications :: Chat",
"License :: OSI Approved :: Apache Software License", "License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",

View file

@@ -48,7 +48,7 @@ try:
except ImportError: except ImportError:
pass pass
__version__ = "1.31.0" __version__ = "1.32.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when # We import here so that we don't have to install a bunch of deps when

View file

@@ -59,6 +59,8 @@ class JoinRules:
KNOCK = "knock" KNOCK = "knock"
INVITE = "invite" INVITE = "invite"
PRIVATE = "private" PRIVATE = "private"
# As defined for MSC3083.
MSC3083_RESTRICTED = "restricted"
class LoginType: class LoginType:
@@ -71,6 +73,11 @@ class LoginType:
DUMMY = "m.login.dummy" DUMMY = "m.login.dummy"
# This is used in the `type` parameter for /register when called by
# an appservice to register a new user.
APP_SERVICE_REGISTRATION_TYPE = "m.login.application_service"
class EventTypes: class EventTypes:
Member = "m.room.member" Member = "m.room.member"
Create = "m.room.create" Create = "m.room.create"

View file

@@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import Hashable, Optional, Tuple from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.storage.databases.main import DataStore
from synapse.types import Requester from synapse.types import Requester
from synapse.util import Clock from synapse.util import Clock
@@ -31,10 +32,13 @@ class Ratelimiter:
burst_count: How many actions that can be performed before being limited. burst_count: How many actions that can be performed before being limited.
""" """
def __init__(self, clock: Clock, rate_hz: float, burst_count: int): def __init__(
self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int
):
self.clock = clock self.clock = clock
self.rate_hz = rate_hz self.rate_hz = rate_hz
self.burst_count = burst_count self.burst_count = burst_count
self.store = store
# An ordered dictionary keeping track of actions, when they were last # An ordered dictionary keeping track of actions, when they were last
# performed and how often. Each entry is a mapping from a key of arbitrary type # performed and how often. Each entry is a mapping from a key of arbitrary type
@@ -46,45 +50,10 @@ class Ratelimiter:
OrderedDict() OrderedDict()
) # type: OrderedDict[Hashable, Tuple[float, int, float]] ) # type: OrderedDict[Hashable, Tuple[float, int, float]]
def can_requester_do_action( async def can_do_action(
self, self,
requester: Requester, requester: Optional[Requester],
rate_hz: Optional[float] = None, key: Optional[Hashable] = None,
burst_count: Optional[int] = None,
update: bool = True,
_time_now_s: Optional[int] = None,
) -> Tuple[bool, float]:
"""Can the requester perform the action?
Args:
requester: The requester to key off when rate limiting. The user property
will be used.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set.
update: Whether to count this check as performing the action
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
Returns:
A tuple containing:
* A bool indicating if they can perform the action now
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return True, -1.0
return self.can_do_action(
requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
)
def can_do_action(
self,
key: Hashable,
rate_hz: Optional[float] = None, rate_hz: Optional[float] = None,
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
update: bool = True, update: bool = True,
@@ -92,9 +61,16 @@ class Ratelimiter:
) -> Tuple[bool, float]: ) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action? """Can the entity (e.g. user or IP address) perform the action?
Checks if the user has ratelimiting disabled in the database by looking
for null/zero values in the `ratelimit_override` table. (Non-zero
values aren't honoured, as they're specific to the event sending
ratelimiter, rather than all ratelimiters)
Args: Args:
key: The key we should use when rate limiting. Can be a user ID requester: The requester that is doing the action, if any. Used to check
(when sending events), an IP address, etc. if the user has ratelimits disabled in the database.
key: An arbitrary key used to classify an action. Defaults to the
requester's user ID.
rate_hz: The long term number of actions that can be performed in a second. rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set. Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited. burst_count: How many actions that can be performed before being limited.
@@ -109,6 +85,30 @@ class Ratelimiter:
* The reactor timestamp for when the action can be performed next. * The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero -1 if rate_hz is less than or equal to zero
""" """
if key is None:
if not requester:
raise ValueError("Must supply at least one of `requester` or `key`")
key = requester.user.to_string()
if requester:
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return True, -1.0
# Check if ratelimiting has been disabled for the user.
#
# Note that we don't use the returned rate/burst count, as the table
# is specifically for the event sending ratelimiter. Instead, we
# only use it to (somewhat cheekily) infer whether the user should
# be subject to any rate limiting or not.
override = await self.store.get_ratelimit_for_user(
requester.authenticated_entity
)
if override and not override.messages_per_second:
return True, -1.0
# Override default values if set # Override default values if set
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
rate_hz = rate_hz if rate_hz is not None else self.rate_hz rate_hz = rate_hz if rate_hz is not None else self.rate_hz
@@ -175,9 +175,10 @@ class Ratelimiter:
else: else:
del self.actions[key] del self.actions[key]
def ratelimit( async def ratelimit(
self, self,
key: Hashable, requester: Optional[Requester],
key: Optional[Hashable] = None,
rate_hz: Optional[float] = None, rate_hz: Optional[float] = None,
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
update: bool = True, update: bool = True,
@@ -185,8 +186,16 @@ class Ratelimiter:
): ):
"""Checks if an action can be performed. If not, raises a LimitExceededError """Checks if an action can be performed. If not, raises a LimitExceededError
Checks if the user has ratelimiting disabled in the database by looking
for null/zero values in the `ratelimit_override` table. (Non-zero
values aren't honoured, as they're specific to the event sending
ratelimiter, rather than all ratelimiters)
Args: Args:
key: An arbitrary key used to classify an action requester: The requester that is doing the action, if any. Used to check for
if the user has ratelimits disabled.
key: An arbitrary key used to classify an action. Defaults to the
requester's user ID.
rate_hz: The long term number of actions that can be performed in a second. rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set. Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited. burst_count: How many actions that can be performed before being limited.
@@ -201,7 +210,8 @@ class Ratelimiter:
""" """
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
allowed, time_allowed = self.can_do_action( allowed, time_allowed = await self.can_do_action(
requester,
key, key,
rate_hz=rate_hz, rate_hz=rate_hz,
burst_count=burst_count, burst_count=burst_count,
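To summarise the reworked interface: can_do_action and ratelimit are now async, take an optional Requester first, and fall back to the requester's user ID when no explicit key is given. A minimal usage sketch (limiter, requester and request_ip are stand-ins, not names from this diff):

async def handle_request(limiter, requester, request_ip):
    # Keyed off the requester's user ID; appservice exemptions and the
    # per-user ratelimit_override check now happen inside the limiter.
    allowed, next_allowed_ts = await limiter.can_do_action(requester)
    if not allowed:
        return False

    # Keyed off an arbitrary value (here an IP address) with no requester;
    # raises LimitExceededError when over the limit.
    await limiter.ratelimit(None, key=request_ip)
    return True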

View file

@@ -57,7 +57,7 @@ class RoomVersion:
state_res = attr.ib(type=int) # one of the StateResolutionVersions state_res = attr.ib(type=int) # one of the StateResolutionVersions
enforce_key_validity = attr.ib(type=bool) enforce_key_validity = attr.ib(type=bool)
# bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules # Before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
special_case_aliases_auth = attr.ib(type=bool) special_case_aliases_auth = attr.ib(type=bool)
# Strictly enforce canonicaljson, do not allow: # Strictly enforce canonicaljson, do not allow:
# * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
@@ -69,6 +69,8 @@ class RoomVersion:
limit_notifications_power_levels = attr.ib(type=bool) limit_notifications_power_levels = attr.ib(type=bool)
# MSC2174/MSC2176: Apply updated redaction rules algorithm. # MSC2174/MSC2176: Apply updated redaction rules algorithm.
msc2176_redaction_rules = attr.ib(type=bool) msc2176_redaction_rules = attr.ib(type=bool)
# MSC3083: Support the 'restricted' join_rule.
msc3083_join_rules = attr.ib(type=bool)
class RoomVersions: class RoomVersions:
@@ -82,6 +84,7 @@ class RoomVersions:
strict_canonicaljson=False, strict_canonicaljson=False,
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False,
) )
V2 = RoomVersion( V2 = RoomVersion(
"2", "2",
@@ -93,6 +96,7 @@ class RoomVersions:
strict_canonicaljson=False, strict_canonicaljson=False,
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False,
) )
V3 = RoomVersion( V3 = RoomVersion(
"3", "3",
@@ -104,6 +108,7 @@ class RoomVersions:
strict_canonicaljson=False, strict_canonicaljson=False,
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False,
) )
V4 = RoomVersion( V4 = RoomVersion(
"4", "4",
@@ -115,6 +120,7 @@ class RoomVersions:
strict_canonicaljson=False, strict_canonicaljson=False,
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False,
) )
V5 = RoomVersion( V5 = RoomVersion(
"5", "5",
@@ -126,6 +132,7 @@ class RoomVersions:
strict_canonicaljson=False, strict_canonicaljson=False,
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False,
) )
V6 = RoomVersion( V6 = RoomVersion(
"6", "6",
@@ -137,6 +144,7 @@ class RoomVersions:
strict_canonicaljson=True, strict_canonicaljson=True,
limit_notifications_power_levels=True, limit_notifications_power_levels=True,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False,
) )
MSC2176 = RoomVersion( MSC2176 = RoomVersion(
"org.matrix.msc2176", "org.matrix.msc2176",
@@ -148,6 +156,19 @@ class RoomVersions:
strict_canonicaljson=True, strict_canonicaljson=True,
limit_notifications_power_levels=True, limit_notifications_power_levels=True,
msc2176_redaction_rules=True, msc2176_redaction_rules=True,
msc3083_join_rules=False,
)
MSC3083 = RoomVersion(
"org.matrix.msc3083",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=True,
) )
@@ -162,4 +183,5 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V6, RoomVersions.V6,
RoomVersions.MSC2176, RoomVersions.MSC2176,
) )
# Note that we do not include MSC3083 here unless it is enabled in the config.
} # type: Dict[str, RoomVersion] } # type: Dict[str, RoomVersion]
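Because the MSC3083 version is only registered when the experimental flag is on (see the experimental config change below), callers should gate on the capability flag of whatever version is registered rather than hard-coding identifiers. A hypothetical helper illustrating the pattern:

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS

def supports_restricted_joins(room_version_id: str) -> bool:
    # Unknown (or not-yet-enabled experimental) versions report False.
    v = KNOWN_ROOM_VERSIONS.get(room_version_id)
    return v is not None and v.msc3083_join_rules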

View file

@@ -281,6 +281,7 @@ class GenericWorkerPresence(BasePresenceHandler):
self.hs = hs self.hs = hs
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.presence_router = hs.get_presence_router()
self._presence_enabled = hs.config.use_presence self._presence_enabled = hs.config.use_presence
# The number of ongoing syncs on this process, by user id. # The number of ongoing syncs on this process, by user id.
@@ -395,7 +396,7 @@ class GenericWorkerPresence(BasePresenceHandler):
return _user_syncing() return _user_syncing()
async def notify_from_replication(self, states, stream_id): async def notify_from_replication(self, states, stream_id):
parties = await get_interested_parties(self.store, states) parties = await get_interested_parties(self.store, self.presence_router, states)
room_ids_to_states, users_to_states = parties room_ids_to_states, users_to_states = parties
self.notifier.on_new_event( self.notifier.on_new_event(

View file

@@ -49,7 +49,7 @@ This is all tied together by the AppServiceScheduler which DIs the required
components. components.
""" """
import logging import logging
from typing import List from typing import List, Optional
from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.events import EventBase from synapse.events import EventBase
@@ -191,11 +191,11 @@ class _TransactionController:
self, self,
service: ApplicationService, service: ApplicationService,
events: List[EventBase], events: List[EventBase],
ephemeral: List[JsonDict] = [], ephemeral: Optional[List[JsonDict]] = None,
): ):
try: try:
txn = await self.store.create_appservice_txn( txn = await self.store.create_appservice_txn(
service=service, events=events, ephemeral=ephemeral service=service, events=events, ephemeral=ephemeral or []
) )
service_is_up = await self._is_service_up(service) service_is_up = await self._is_service_up(service)
if service_is_up: if service_is_up:

View file

@@ -1,4 +1,4 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -12,38 +12,131 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api.constants import EventTypes import logging
from typing import Iterable
from ._base import Config from synapse.api.constants import EventTypes
from synapse.config._base import Config, ConfigError
from synapse.config._util import validate_config
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
class ApiConfig(Config): class ApiConfig(Config):
section = "api" section = "api"
def read_config(self, config, **kwargs): def read_config(self, config: JsonDict, **kwargs):
self.room_invite_state_types = config.get( validate_config(_MAIN_SCHEMA, config, ())
"room_invite_state_types", self.room_prejoin_state = list(self._get_prejoin_state_types(config))
[
def generate_config_section(cls, **kwargs) -> str:
formatted_default_state_types = "\n".join(
" # - %s" % (t,) for t in _DEFAULT_PREJOIN_STATE_TYPES
)
return """\
## API Configuration ##
# Controls for the state that is shared with users who receive an invite
# to a room
#
room_prejoin_state:
# By default, the following state event types are shared with users who
# receive invites to the room:
#
%(formatted_default_state_types)s
#
# Uncomment the following to disable these defaults (so that only the event
# types listed in 'additional_event_types' are shared). Defaults to 'false'.
#
#disable_default_event_types: true
# Additional state event types to share with users when they are invited
# to a room.
#
# By default, this list is empty (so only the default event types are shared).
#
#additional_event_types:
# - org.example.custom.event.type
""" % {
"formatted_default_state_types": formatted_default_state_types
}
def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
"""Get the event types to include in the prejoin state
Parses the config and returns an iterable of the event types to be included.
"""
room_prejoin_state_config = config.get("room_prejoin_state") or {}
# backwards-compatibility support for room_invite_state_types
if "room_invite_state_types" in config:
# if both "room_invite_state_types" and "room_prejoin_state" are set, then
# we don't really know what to do.
if room_prejoin_state_config:
raise ConfigError(
"Can't specify both 'room_invite_state_types' and 'room_prejoin_state' "
"in config"
)
logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
yield from config["room_invite_state_types"]
return
if not room_prejoin_state_config.get("disable_default_event_types"):
yield from _DEFAULT_PREJOIN_STATE_TYPES
if self.spaces_enabled:
# MSC1772 suggests adding m.room.create to the prejoin state
yield EventTypes.Create
yield from room_prejoin_state_config.get("additional_event_types", [])
_ROOM_INVITE_STATE_TYPES_WARNING = """\
WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
and replaced with 'room_prejoin_state'. New features may not work correctly
unless 'room_invite_state_types' is removed. See the sample configuration file for
details of 'room_prejoin_state'.
--------------------------------------------------------------------------------
"""
_DEFAULT_PREJOIN_STATE_TYPES = [
EventTypes.JoinRules, EventTypes.JoinRules,
EventTypes.CanonicalAlias, EventTypes.CanonicalAlias,
EventTypes.RoomAvatar, EventTypes.RoomAvatar,
EventTypes.RoomEncryption, EventTypes.RoomEncryption,
EventTypes.Name, EventTypes.Name,
], ]
)
def generate_config_section(cls, **kwargs):
return """\
## API Configuration ##
# A list of event types that will be included in the room_invite_state # room_prejoin_state can either be None (as it is in the default config), or
# # an object containing other config settings
#room_invite_state_types: _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
# - "{JoinRules}" "oneOf": [
# - "{CanonicalAlias}" {
# - "{RoomAvatar}" "type": "object",
# - "{RoomEncryption}" "properties": {
# - "{Name}" "disable_default_event_types": {"type": "boolean"},
""".format( "additional_event_types": {
**vars(EventTypes) "type": "array",
) "items": {"type": "string"},
},
},
},
{"type": "null"},
]
}
# the legacy room_invite_state_types setting
_ROOM_INVITE_STATE_TYPES_SCHEMA = {"type": "array", "items": {"type": "string"}}
_MAIN_SCHEMA = {
"type": "object",
"properties": {
"room_prejoin_state": _ROOM_PREJOIN_STATE_CONFIG_SCHEMA,
"room_invite_state_types": _ROOM_INVITE_STATE_TYPES_SCHEMA,
},
}
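A condensed sketch of the precedence _get_prejoin_state_types implements, ignoring the legacy room_invite_state_types path and schema validation (standalone generator for illustration; _DEFAULT_PREJOIN_STATE_TYPES is the list defined above):

def prejoin_state_types(config: dict, spaces_enabled: bool = False):
    prejoin = config.get("room_prejoin_state") or {}
    if not prejoin.get("disable_default_event_types"):
        yield from _DEFAULT_PREJOIN_STATE_TYPES
        if spaces_enabled:
            # MSC1772 suggests also sharing m.room.create
            yield "m.room.create"
    yield from prejoin.get("additional_event_types", [])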

View file

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.config._base import Config from synapse.config._base import Config
from synapse.types import JsonDict from synapse.types import JsonDict
@@ -27,7 +28,11 @@ class ExperimentalConfig(Config):
# MSC2858 (multiple SSO identity providers) # MSC2858 (multiple SSO identity providers)
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
# Spaces (MSC1772, MSC2946, etc)
# Spaces (MSC1772, MSC2946, MSC3083, etc)
self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
if self.spaces_enabled:
KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083
# MSC3026 (busy presence state) # MSC3026 (busy presence state)
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool

View file

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict from typing import Dict, Optional
from ._base import Config from ._base import Config
@@ -21,8 +21,10 @@ class RateLimitConfig:
def __init__( def __init__(
self, self,
config: Dict[str, float], config: Dict[str, float],
defaults={"per_second": 0.17, "burst_count": 3.0}, defaults: Optional[Dict[str, float]] = None,
): ):
defaults = defaults or {"per_second": 0.17, "burst_count": 3.0}
self.per_second = config.get("per_second", defaults["per_second"]) self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = int(config.get("burst_count", defaults["burst_count"])) self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
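The reason for replacing the defaults={...} argument here (and the pdus=[], internal_metadata_dict={} and similar defaults in the files below) is Python's shared-mutable-default pitfall: a default value is created once at function definition time, so mutations leak across calls. A self-contained demonstration:

def broken(value, acc=[]):  # one list shared by every call
    acc.append(value)
    return acc

assert broken(1) == [1]
assert broken(2) == [1, 2]  # state leaked from the first call

def fixed(value, acc=None):  # fresh list per call
    acc = acc if acc is not None else []
    acc.append(value)
    return acc

assert fixed(2) == [2]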

View file

@@ -298,9 +298,9 @@ class RegistrationConfig(Config):
# #
#allowed_local_3pids: #allowed_local_3pids:
# - medium: email # - medium: email
# pattern: '.*@matrix\\.org' # pattern: '^[^@]+@matrix\\.org$'
# - medium: email # - medium: email
# pattern: '.*@vector\\.im' # pattern: '^[^@]+@vector\\.im$'
# - medium: msisdn # - medium: msisdn
# pattern: '\\+44' # pattern: '\\+44'

View file

@@ -27,6 +27,7 @@ import yaml
from netaddr import AddrFormatError, IPNetwork, IPSet from netaddr import AddrFormatError, IPNetwork, IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.stringutils import parse_and_validate_server_name
from ._base import Config, ConfigError from ._base import Config, ConfigError
@@ -238,8 +239,21 @@ class ServerConfig(Config):
self.public_baseurl = config.get("public_baseurl") self.public_baseurl = config.get("public_baseurl")
# Whether to enable user presence. # Whether to enable user presence.
presence_config = config.get("presence") or {}
self.use_presence = presence_config.get("enabled")
if self.use_presence is None:
self.use_presence = config.get("use_presence", True) self.use_presence = config.get("use_presence", True)
# Custom presence router module
self.presence_router_module_class = None
self.presence_router_config = None
presence_router_config = presence_config.get("presence_router")
if presence_router_config:
(
self.presence_router_module_class,
self.presence_router_config,
) = load_module(presence_router_config, ("presence", "presence_router"))
# Whether to update the user directory or not. This should be set to # Whether to update the user directory or not. This should be set to
# false only if we are updating the user directory in a worker # false only if we are updating the user directory in a worker
self.update_user_directory = config.get("update_user_directory", True) self.update_user_directory = config.get("update_user_directory", True)
@@ -834,9 +848,28 @@ class ServerConfig(Config):
# #
#soft_file_limit: 0 #soft_file_limit: 0
# Set to false to disable presence tracking on this homeserver. # Presence tracking allows users to see the state (e.g online/offline)
# of other local and remote users.
# #
#use_presence: false presence:
# Uncomment to disable presence tracking on this homeserver. This option
# replaces the previous top-level 'use_presence' option.
#
#enabled: false
# Presence routers are third-party modules that can specify additional logic
# governing where presence updates from users are routed.
#
presence_router:
# The custom module's class. Uncomment to use a custom presence router module.
#
#module: "my_custom_router.PresenceRouter"
# Configuration options of the custom module. Refer to your module's
# documentation for available options.
#
#config:
# example_option: 'something'
# Whether to require authentication to retrieve profile data (avatars, # Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to # display names) of other users through the client API. Defaults to
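A condensed sketch of the back-compat lookup added above: the new nested presence.enabled setting wins, falling back to the legacy top-level use_presence, which itself defaults to True (standalone function for illustration):

def presence_enabled(config: dict) -> bool:
    presence = config.get("presence") or {}
    enabled = presence.get("enabled")
    if enabled is None:
        enabled = config.get("use_presence", True)
    return enabled

assert presence_enabled({}) is True
assert presence_enabled({"use_presence": False}) is False
assert presence_enabled({"presence": {"enabled": False}}) is False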

View file

@@ -270,7 +270,7 @@ class TlsConfig(Config):
tls_certificate_path, tls_certificate_path,
tls_private_key_path, tls_private_key_path,
acme_domain, acme_domain,
**kwargs **kwargs,
): ):
"""If the acme_domain is specified acme will be enabled. """If the acme_domain is specified acme will be enabled.
If the TLS paths are not specified the default will be certs in the If the TLS paths are not specified the default will be certs in the

View file

@@ -162,7 +162,7 @@ def check(
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
_is_membership_change_allowed(event, auth_events) _is_membership_change_allowed(room_version_obj, event, auth_events)
logger.debug("Allowing! %s", event) logger.debug("Allowing! %s", event)
return return
@@ -220,8 +220,19 @@ def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
def _is_membership_change_allowed( def _is_membership_change_allowed(
event: EventBase, auth_events: StateMap[EventBase] room_version: RoomVersion, event: EventBase, auth_events: StateMap[EventBase]
) -> None: ) -> None:
"""
Confirms that the event which changes membership is an allowed change.
Args:
room_version: The version of the room.
event: The event to check.
auth_events: The current auth events of the room.
Raises:
AuthError if the event is not allowed.
"""
membership = event.content["membership"] membership = event.content["membership"]
# Check if this is the room creator joining: # Check if this is the room creator joining:
@@ -315,14 +326,19 @@ def _is_membership_change_allowed(
if user_level < invite_level: if user_level < invite_level:
raise AuthError(403, "You don't have permission to invite users") raise AuthError(403, "You don't have permission to invite users")
elif Membership.JOIN == membership: elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were: # Joins are valid iff caller == target and:
# invited: They are accepting the invitation # * They are not banned.
# joined: It's a NOOP # * They are accepting a previously sent invitation.
# * They are already joined (it's a NOOP).
# * The room is public or restricted.
if event.user_id != target_user_id: if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.") raise AuthError(403, "Cannot force another user to join.")
elif target_banned: elif target_banned:
raise AuthError(403, "You are banned from this room") raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC: elif join_rule == JoinRules.PUBLIC or (
room_version.msc3083_join_rules
and join_rule == JoinRules.MSC3083_RESTRICTED
):
pass pass
elif join_rule == JoinRules.INVITE: elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited: if not caller_in_room and not caller_invited:
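A condensed, hypothetical restatement of the join branch above. Note that at this layer a restricted room is treated like a public one when the room version supports MSC3083; the membership restrictions themselves are enforced elsewhere when the join is processed:

def join_allowed(room_version, join_rule, caller_invited, target_banned) -> bool:
    # Sketch only: assumes caller == target has already been verified.
    if target_banned:
        return False
    if join_rule == "public":
        return True
    if room_version.msc3083_join_rules and join_rule == "restricted":
        return True
    return join_rule == "invite" and caller_invited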

View file

@@ -330,9 +330,11 @@ class FrozenEvent(EventBase):
self, self,
event_dict: JsonDict, event_dict: JsonDict,
room_version: RoomVersion, room_version: RoomVersion,
internal_metadata_dict: JsonDict = {}, internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None, rejected_reason: Optional[str] = None,
): ):
internal_metadata_dict = internal_metadata_dict or {}
event_dict = dict(event_dict) event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a # Signatures is a dict of dicts, and this is faster than doing a
@@ -386,9 +388,11 @@ class FrozenEventV2(EventBase):
self, self,
event_dict: JsonDict, event_dict: JsonDict,
room_version: RoomVersion, room_version: RoomVersion,
internal_metadata_dict: JsonDict = {}, internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None, rejected_reason: Optional[str] = None,
): ):
internal_metadata_dict = internal_metadata_dict or {}
event_dict = dict(event_dict) event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a # Signatures is a dict of dicts, and this is faster than doing a
@@ -507,9 +511,11 @@ def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
def make_event_from_dict( def make_event_from_dict(
event_dict: JsonDict, event_dict: JsonDict,
room_version: RoomVersion = RoomVersions.V1, room_version: RoomVersion = RoomVersions.V1,
internal_metadata_dict: JsonDict = {}, internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None, rejected_reason: Optional[str] = None,
) -> EventBase: ) -> EventBase:
"""Construct an EventBase from the given event dict""" """Construct an EventBase from the given event dict"""
event_type = _event_type_from_format_version(room_version.event_format) event_type = _event_type_from_format_version(room_version.event_format)
return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason) return event_type(
event_dict, room_version, internal_metadata_dict or {}, rejected_reason
)

View file

@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
from synapse.api.presence import UserPresenceState
if TYPE_CHECKING:
from synapse.server import HomeServer
class PresenceRouter:
"""
A module that the homeserver will call upon to help route user presence updates to
additional destinations. If a custom presence router is configured, calls will be
passed to that instead.
"""
ALL_USERS = "ALL"
def __init__(self, hs: "HomeServer"):
self.custom_presence_router = None
# Check whether a custom presence router module has been configured
if hs.config.presence_router_module_class:
# Initialise the module
self.custom_presence_router = hs.config.presence_router_module_class(
config=hs.config.presence_router_config, module_api=hs.get_module_api()
)
# Ensure the module has implemented the required methods
required_methods = ["get_users_for_states", "get_interested_users"]
for method_name in required_methods:
if not hasattr(self.custom_presence_router, method_name):
raise Exception(
"PresenceRouter module '%s' must implement all required methods: %s"
% (
hs.config.presence_router_module_class.__name__,
", ".join(required_methods),
)
)
async def get_users_for_states(
self,
state_updates: Iterable[UserPresenceState],
) -> Dict[str, Set[UserPresenceState]]:
"""
Given an iterable of user presence updates, determine where each one
needs to go.
Args:
state_updates: An iterable of user presence state updates.
Returns:
A dictionary of user_id -> set of UserPresenceState, indicating which
presence updates each user should receive.
"""
if self.custom_presence_router is not None:
# Ask the custom module
return await self.custom_presence_router.get_users_for_states(
state_updates=state_updates
)
# Don't include any extra destinations for presence updates
return {}
async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
"""
Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with.
Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate
that this user should receive all incoming local and remote presence updates.
Note that this method will only be called for local users, but can return users
that are local or remote.
Args:
user_id: A user requesting presence updates.
Returns:
A set of user IDs to return presence updates for, or ALL_USERS to return all
known updates.
"""
if self.custom_presence_router is not None:
# Ask the custom module for interested users
return await self.custom_presence_router.get_interested_users(
user_id=user_id
)
# A custom presence router is not defined.
# Don't report any additional interested users
return set()
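For illustration, a hypothetical custom module satisfying the two methods the loader above requires (module, class and user names are invented):

from synapse.events.presence_router import PresenceRouter

class ExampleRouter:
    def __init__(self, config, module_api):
        self._watchers = {"@alice:example.org"}

    async def get_users_for_states(self, state_updates):
        # Fan every presence update out to each watcher.
        return {user: set(state_updates) for user in self._watchers}

    async def get_interested_users(self, user_id):
        # Watchers receive all presence the server sees.
        if user_id in self._watchers:
            return PresenceRouter.ALL_USERS
        return set()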

View file

@@ -102,7 +102,7 @@ class FederationClient(FederationBase):
max_len=1000, max_len=1000,
expiry_ms=120 * 1000, expiry_ms=120 * 1000,
reset_expiry_on_get=False, reset_expiry_on_get=False,
) ) # type: ExpiringCache[str, EventBase]
def _clear_tried_cache(self): def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache""" """Clear pdu_destination_tried cache"""

View file

@@ -739,22 +739,20 @@ class FederationServer(FederationBase):
await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
def __str__(self): def __str__(self) -> str:
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name
async def exchange_third_party_invite( async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
): ) -> None:
ret = await self.handler.exchange_third_party_invite( await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed sender_user_id, target_user_id, room_id, signed
) )
return ret
async def on_exchange_third_party_invite_request(self, event_dict: Dict): async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
ret = await self.handler.on_exchange_third_party_invite_request(event_dict) await self.handler.on_exchange_third_party_invite_request(event_dict)
return ret
async def check_server_matches_acl(self, server_name: str, room_id: str): async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
"""Check if the given server is allowed by the server ACLs in the room """Check if the given server is allowed by the server ACLs in the room
Args: Args:
@@ -870,6 +868,7 @@ class FederationHandlerRegistry:
# A rate limiter for incoming room key requests per origin. # A rate limiter for incoming room key requests per origin.
self._room_key_request_rate_limiter = Ratelimiter( self._room_key_request_rate_limiter = Ratelimiter(
store=hs.get_datastore(),
clock=self.clock, clock=self.clock,
rate_hz=self.config.rc_key_requests.per_second, rate_hz=self.config.rc_key_requests.per_second,
burst_count=self.config.rc_key_requests.burst_count, burst_count=self.config.rc_key_requests.burst_count,
@@ -877,7 +876,7 @@ class FederationHandlerRegistry:
def register_edu_handler( def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
): ) -> None:
"""Sets the handler callable that will be used to handle an incoming """Sets the handler callable that will be used to handle an incoming
federation EDU of the given type. federation EDU of the given type.
@@ -896,7 +895,7 @@ class FederationHandlerRegistry:
def register_query_handler( def register_query_handler(
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
): ) -> None:
"""Sets the handler callable that will be used to handle an incoming """Sets the handler callable that will be used to handle an incoming
federation query of the given type. federation query of the given type.
@@ -914,15 +913,17 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
def register_instance_for_edu(self, edu_type: str, instance_name: str): def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
"""Register that the EDU handler is on a different instance than master.""" """Register that the EDU handler is on a different instance than master."""
self._edu_type_to_instance[edu_type] = [instance_name] self._edu_type_to_instance[edu_type] = [instance_name]
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]): def register_instances_for_edu(
self, edu_type: str, instance_names: List[str]
) -> None:
"""Register that the EDU handler is on multiple instances.""" """Register that the EDU handler is on multiple instances."""
self._edu_type_to_instance[edu_type] = instance_names self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict): async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
if not self.config.use_presence and edu_type == EduTypes.Presence: if not self.config.use_presence and edu_type == EduTypes.Presence:
return return
@@ -930,7 +931,9 @@ class FederationHandlerRegistry:
# the limit, drop them. # the limit, drop them.
if ( if (
edu_type == EduTypes.RoomKeyRequest edu_type == EduTypes.RoomKeyRequest
and not self._room_key_request_rate_limiter.can_do_action(origin) and not await self._room_key_request_rate_limiter.can_do_action(
None, origin
)
): ):
return return
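A hypothetical registration against this registry, matching the handler signature enforced above (an async callable taking an origin and the EDU content; registry stands in for a FederationHandlerRegistry instance):

async def on_example_edu(origin: str, content: dict) -> None:
    # Handle an incoming EDU of our custom type.
    ...

registry.register_edu_handler("org.example.ping", on_example_edu)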

View file

@@ -44,6 +44,7 @@ from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
from synapse.util.metrics import Measure, measure_func from synapse.util.metrics import Measure, measure_func
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.events.presence_router import PresenceRouter
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -162,6 +163,7 @@ class FederationSender(AbstractFederationSender):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self._presence_router = None # type: Optional[PresenceRouter]
self._transaction_manager = TransactionManager(hs) self._transaction_manager = TransactionManager(hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
@@ -584,7 +586,22 @@ class FederationSender(AbstractFederationSender):
"""Given a list of states populate self.pending_presence_by_dest and """Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination poke to send a new transaction to each destination
""" """
hosts_and_states = await get_interested_remotes(self.store, states, self.state) # We pull the presence router here instead of __init__
# to prevent a dependency cycle:
#
# AuthHandler -> Notifier -> FederationSender
# -> PresenceRouter -> ModuleApi -> AuthHandler
if self._presence_router is None:
self._presence_router = self.hs.get_presence_router()
assert self._presence_router is not None
hosts_and_states = await get_interested_remotes(
self.store,
self._presence_router,
states,
self.state,
)
for destinations, states in hosts_and_states: for destinations, states in hosts_and_states:
for destination in destinations: for destination in destinations:
@@ -717,16 +734,18 @@ class FederationSender(AbstractFederationSender):
self._catchup_after_startup_timer = None self._catchup_after_startup_timer = None
break break
last_processed = destinations_to_wake[-1]
destinations_to_wake = [ destinations_to_wake = [
d d
for d in destinations_to_wake for d in destinations_to_wake
if self._federation_shard_config.should_handle(self._instance_name, d) if self._federation_shard_config.should_handle(self._instance_name, d)
] ]
for last_processed in destinations_to_wake: for destination in destinations_to_wake:
logger.info( logger.info(
"Destination %s has outstanding catch-up, waking up.", "Destination %s has outstanding catch-up, waking up.",
last_processed, last_processed,
) )
self.wake_destination(last_processed) self.wake_destination(destination)
await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC) await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC)

View file

@@ -29,6 +29,7 @@ from synapse.api.presence import UserPresenceState
from synapse.events import EventBase from synapse.events import EventBase
from synapse.federation.units import Edu from synapse.federation.units import Edu
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.logging.opentracing import SynapseTags, set_tag
from synapse.metrics import sent_transactions_counter from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import ReadReceipt from synapse.types import ReadReceipt
@@ -557,6 +558,13 @@ class PerDestinationQueue:
contents, stream_id = await self._store.get_new_device_msgs_for_remote( contents, stream_id = await self._store.get_new_device_msgs_for_remote(
self._destination, last_device_stream_id, to_device_stream_id, limit self._destination, last_device_stream_id, to_device_stream_id, limit
) )
for content in contents:
message_id = content.get("message_id")
if not message_id:
continue
set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
edus = [ edus = [
Edu( Edu(
origin=self._server_name, origin=self._server_name,
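The loop added above re-tags the federation span with the message_id stamped on each to-device message when it was sent (see the DeviceMessageHandler change further down), so both halves of a delivery can be correlated in traces. A minimal sketch of the send-side tagging, assuming an active opentracing span:

from synapse.logging.opentracing import SynapseTags, set_tag
from synapse.util.stringutils import random_string

message_id = random_string(16)  # stamped into the message content
set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)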

View file

@@ -425,13 +425,9 @@ class FederationSendServlet(BaseFederationServlet):
logger.exception(e) logger.exception(e)
return 400, {"error": "Invalid transaction"} return 400, {"error": "Invalid transaction"}
try:
code, response = await self.handler.on_incoming_transaction( code, response = await self.handler.on_incoming_transaction(
origin, transaction_data origin, transaction_data
) )
except Exception:
logger.exception("on_incoming_transaction failed")
raise
return code, response return code, response
@@ -620,8 +616,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id): async def on_PUT(self, origin, content, query, room_id):
content = await self.handler.on_exchange_third_party_invite_request(content) await self.handler.on_exchange_third_party_invite_request(content)
return 200, content return 200, {}
class FederationClientKeysQueryServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet):

View file

@@ -18,6 +18,7 @@ server protocol.
""" """
import logging import logging
from typing import Optional
import attr import attr
@@ -98,7 +99,7 @@ class Transaction(JsonEncodedObject):
"pdus", "pdus",
] ]
def __init__(self, transaction_id=None, pdus=[], **kwargs): def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs):
"""If we include a list of pdus then we decode then as PDU's """If we include a list of pdus then we decode then as PDU's
automatically. automatically.
""" """
@@ -107,7 +108,7 @@ class Transaction(JsonEncodedObject):
if "edus" in kwargs and not kwargs["edus"]: if "edus" in kwargs and not kwargs["edus"]:
del kwargs["edus"] del kwargs["edus"]
super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs) super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs)
@staticmethod @staticmethod
def create_new(pdus, **kwargs): def create_new(pdus, **kwargs):

View file

@@ -49,7 +49,7 @@ class BaseHandler:
# The rate_hz and burst_count are overridden on a per-user basis # The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter( self.request_ratelimiter = Ratelimiter(
clock=self.clock, rate_hz=0, burst_count=0 store=self.store, clock=self.clock, rate_hz=0, burst_count=0
) )
self._rc_message = self.hs.config.rc_message self._rc_message = self.hs.config.rc_message
@@ -57,6 +57,7 @@ class BaseHandler:
# by the presence of rate limits in the config # by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction: if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter( self.admin_redaction_ratelimiter = Ratelimiter(
store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second, rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count, burst_count=self.hs.config.rc_admin_redaction.burst_count,
@@ -91,11 +92,6 @@ class BaseHandler:
if app_service is not None: if app_service is not None:
return # do not ratelimit app service senders return # do not ratelimit app service senders
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return
messages_per_second = self._rc_message.per_second messages_per_second = self._rc_message.per_second
burst_count = self._rc_message.burst_count burst_count = self._rc_message.burst_count
@@ -113,11 +109,11 @@ class BaseHandler:
if is_admin_redaction and self.admin_redaction_ratelimiter: if is_admin_redaction and self.admin_redaction_ratelimiter:
# If we have separate config for admin redactions, use a separate # If we have separate config for admin redactions, use a separate
# ratelimiter as to not have user_ids clash # ratelimiter as to not have user_ids clash
self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
else: else:
# Override rate and burst count per-user # Override rate and burst count per-user
self.request_ratelimiter.ratelimit( await self.request_ratelimiter.ratelimit(
user_id, requester,
rate_hz=messages_per_second, rate_hz=messages_per_second,
burst_count=burst_count, burst_count=burst_count,
update=update, update=update,

View file

@@ -18,7 +18,7 @@ import email.utils
import logging import logging
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText from email.mime.text import MIMEText
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List, Optional
from synapse.api.errors import StoreError, SynapseError from synapse.api.errors import StoreError, SynapseError
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@@ -241,7 +241,10 @@ class AccountValidityHandler:
return True return True
async def renew_account_for_user( async def renew_account_for_user(
self, user_id: str, expiration_ts: int = None, email_sent: bool = False self,
user_id: str,
expiration_ts: Optional[int] = None,
email_sent: bool = False,
) -> int: ) -> int:
"""Renews the account attached to a given user by pushing back the """Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's expiration date by the current validity period in the server's

View file

@@ -182,7 +182,7 @@ class ApplicationServicesHandler:
self, self,
stream_key: str, stream_key: str,
new_token: Optional[int], new_token: Optional[int],
users: Collection[Union[str, UserID]] = [], users: Optional[Collection[Union[str, UserID]]] = None,
): ):
"""This is called by the notifier in the background """This is called by the notifier in the background
when an ephemeral event is handled by the homeserver. when an ephemeral event is handled by the homeserver.
@@ -215,7 +215,7 @@ class ApplicationServicesHandler:
# We only start a new background process if necessary rather than # We only start a new background process if necessary rather than
# optimistically (to cut down on overhead). # optimistically (to cut down on overhead).
self._notify_interested_services_ephemeral( self._notify_interested_services_ephemeral(
services, stream_key, new_token, users services, stream_key, new_token, users or []
) )
@wrap_as_background_process("notify_interested_services_ephemeral") @wrap_as_background_process("notify_interested_services_ephemeral")

View file

@@ -238,6 +238,7 @@ class AuthHandler(BaseHandler):
# Ratelimiter for failed auth during UIA. Uses same ratelimit config # Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`. # as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter( self._failed_uia_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=self.hs.config.rc_login_failed_attempts.per_second, rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count, burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
@@ -248,6 +249,7 @@ class AuthHandler(BaseHandler):
# Ratelimiter for failed /login attempts # Ratelimiter for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter( self._failed_login_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second, rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count, burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
@@ -352,7 +354,7 @@ class AuthHandler(BaseHandler):
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts # Check if we should be ratelimited due to too many previous failed attempts
self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False) await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False)
# build a list of supported flows # build a list of supported flows
supported_ui_auth_types = await self._get_available_ui_auth_types( supported_ui_auth_types = await self._get_available_ui_auth_types(
@@ -373,7 +375,9 @@ class AuthHandler(BaseHandler):
) )
except LoginError: except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise). # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id) await self._failed_uia_attempts_ratelimiter.can_do_action(
requester,
)
raise raise
# find the completed login type # find the completed login type
@@ -982,8 +986,8 @@ class AuthHandler(BaseHandler):
# We also apply account rate limiting using the 3PID as a key, as # We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID. # otherwise using 3PID bypasses the ratelimiting based on user ID.
if ratelimit: if ratelimit:
self._failed_login_attempts_ratelimiter.ratelimit( await self._failed_login_attempts_ratelimiter.ratelimit(
(medium, address), update=False None, (medium, address), update=False
) )
# Check for login providers that support 3pid login types # Check for login providers that support 3pid login types
@@ -1016,8 +1020,8 @@ class AuthHandler(BaseHandler):
# this code path, which is fine as then the per-user ratelimit # this code path, which is fine as then the per-user ratelimit
# will kick in below. # will kick in below.
if ratelimit: if ratelimit:
self._failed_login_attempts_ratelimiter.can_do_action( await self._failed_login_attempts_ratelimiter.can_do_action(
(medium, address) None, (medium, address)
) )
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@@ -1039,8 +1043,8 @@ class AuthHandler(BaseHandler):
# Check if we've hit the failed ratelimit (but don't update it) # Check if we've hit the failed ratelimit (but don't update it)
if ratelimit: if ratelimit:
self._failed_login_attempts_ratelimiter.ratelimit( await self._failed_login_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(), update=False None, qualified_user_id.lower(), update=False
) )
try: try:
@@ -1051,8 +1055,8 @@ class AuthHandler(BaseHandler):
# exception and masking the LoginError. The actual ratelimiting # exception and masking the LoginError. The actual ratelimiting
# should have happened above. # should have happened above.
if ratelimit: if ratelimit:
self._failed_login_attempts_ratelimiter.can_do_action( await self._failed_login_attempts_ratelimiter.can_do_action(
qualified_user_id.lower() None, qualified_user_id.lower()
) )
raise raise
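The `rate_hz`/`burst_count` pair threaded through these constructors is token-bucket tuning: the bucket holds `burst_count` tokens and refills at `rate_hz` tokens per second, and `update=False` peeks at the bucket without spending a token, so only a confirmed failure is recorded. A minimal synchronous sketch of that behaviour (illustrative only; Synapse's `Ratelimiter` is async, per-key, and now takes the datastore, which lets it exempt senders such as the appservices whose explicit check this diff removes in favour of passing the requester):

import time

class TokenBucketSketch:
    """Toy token bucket mirroring the rate_hz/burst_count knobs above."""

    def __init__(self, rate_hz: float, burst_count: int):
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        self.tokens = float(burst_count)
        self.last_check = time.monotonic()

    def can_do_action(self, update: bool = True) -> bool:
        now = time.monotonic()
        # Refill for the elapsed time, capped at the burst size.
        self.tokens = min(
            float(self.burst_count),
            self.tokens + (now - self.last_check) * self.rate_hz,
        )
        self.last_check = now
        if self.tokens < 1.0:
            return False  # over the limit; the caller raises LimitExceededError
        if update:
            self.tokens -= 1.0  # update=False only peeks, as in ratelimit(update=False)
        return True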

View file

@ -631,7 +631,7 @@ class DeviceListUpdater:
max_len=10000, max_len=10000,
expiry_ms=30 * 60 * 1000, expiry_ms=30 * 60 * 1000,
iterable=True, iterable=True,
) ) # type: ExpiringCache[str, Set[str]]
# Attempt to resync out of sync device lists every 30s. # Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False self._resync_retry_in_progress = False
@ -760,7 +760,7 @@ class DeviceListUpdater:
"""Given a list of updates for a user figure out if we need to do a full """Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta. resync, or whether we have enough data that we can just apply the delta.
""" """
seen_updates = self._seen_updates.get(user_id, set()) seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
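The trailing `# type:` comments added in this hunk are mypy's comment-style annotations, handy when the value being assigned is a multi-line constructor call that leaves nowhere tidy for an inline annotation. A tiny self-contained example of the same pattern:

from typing import Dict, Set

# mypy reads the trailing comment exactly as it would read an inline
# "seen_updates: Dict[str, Set[str]] = {}" annotation.
seen_updates = {}  # type: Dict[str, Set[str]]
seen_updates.setdefault("@user:example.org", set()).add("stream-id-1")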

View file

@ -21,10 +21,10 @@ from synapse.api.errors import SynapseError
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
SynapseTags,
get_active_span_text_map, get_active_span_text_map,
log_kv, log_kv,
set_tag, set_tag,
start_active_span,
) )
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
@ -81,6 +81,7 @@ class DeviceMessageHandler:
) )
self._ratelimiter = Ratelimiter( self._ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=hs.config.rc_key_requests.per_second, rate_hz=hs.config.rc_key_requests.per_second,
burst_count=hs.config.rc_key_requests.burst_count, burst_count=hs.config.rc_key_requests.burst_count,
@ -182,7 +183,10 @@ class DeviceMessageHandler:
) -> None: ) -> None:
sender_user_id = requester.user.to_string() sender_user_id = requester.user.to_string()
set_tag("number_of_messages", len(messages)) message_id = random_string(16)
set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
log_kv({"number_of_to_device_messages": len(messages)})
set_tag("sender", sender_user_id) set_tag("sender", sender_user_id)
local_messages = {} local_messages = {}
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
@ -191,8 +195,8 @@ class DeviceMessageHandler:
if ( if (
message_type == EduTypes.RoomKeyRequest message_type == EduTypes.RoomKeyRequest
and user_id != sender_user_id and user_id != sender_user_id
and self._ratelimiter.can_do_action( and await self._ratelimiter.can_do_action(
(sender_user_id, requester.device_id) requester, (sender_user_id, requester.device_id)
) )
): ):
continue continue
@ -204,23 +208,27 @@ class DeviceMessageHandler:
"content": message_content, "content": message_content,
"type": message_type, "type": message_type,
"sender": sender_user_id, "sender": sender_user_id,
"message_id": message_id,
} }
for device_id, message_content in by_device.items() for device_id, message_content in by_device.items()
} }
if messages_by_device: if messages_by_device:
local_messages[user_id] = messages_by_device local_messages[user_id] = messages_by_device
log_kv(
{
"user_id": user_id,
"device_id": list(messages_by_device),
}
)
else: else:
destination = get_domain_from_id(user_id) destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device remote_messages.setdefault(destination, {})[user_id] = by_device
message_id = random_string(16)
context = get_active_span_text_map() context = get_active_span_text_map()
remote_edu_contents = {} remote_edu_contents = {}
for destination, messages in remote_messages.items(): for destination, messages in remote_messages.items():
with start_active_span("to_device_for_user"): log_kv({"destination": destination})
set_tag("destination", destination)
remote_edu_contents[destination] = { remote_edu_contents[destination] = {
"messages": messages, "messages": messages,
"sender": sender_user_id, "sender": sender_user_id,
@ -229,7 +237,6 @@ class DeviceMessageHandler:
"org.matrix.opentracing_context": json_encoder.encode(context), "org.matrix.opentracing_context": json_encoder.encode(context),
} }
log_kv({"local_messages": local_messages})
stream_id = await self.store.add_messages_to_device_inbox( stream_id = await self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents local_messages, remote_edu_contents
) )
@ -238,7 +245,6 @@ class DeviceMessageHandler:
"to_device_key", stream_id, users=local_messages.keys() "to_device_key", stream_id, users=local_messages.keys()
) )
log_kv({"remote_messages": remote_messages})
if self.federation_sender: if self.federation_sender:
for destination in remote_messages.keys(): for destination in remote_messages.keys():
# Enqueue a new federation transaction to send the new # Enqueue a new federation transaction to send the new
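This hunk mints a single random `message_id` per /sendToDevice request and stamps it on every per-device copy, so one ID can be followed through the logs for both the local inbox write and any federation sends. A rough sketch of that fan-out (the field layout follows the code above; `secrets.token_hex` stands in for Synapse's `random_string`):

import secrets
from typing import Any, Dict

def fan_out_local_messages(
    sender: str,
    message_type: str,
    messages: Dict[str, Dict[str, Any]],  # user_id -> device_id -> content
) -> Dict[str, Dict[str, Dict[str, Any]]]:
    # One message_id per request, shared by every per-device copy.
    message_id = secrets.token_hex(8)
    return {
        user_id: {
            device_id: {
                "content": content,
                "type": message_type,
                "sender": sender,
                "message_id": message_id,
            }
            for device_id, content in by_device.items()
        }
        for user_id, by_device in messages.items()
    }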

View file

@ -38,7 +38,6 @@ from synapse.types import (
) )
from synapse.util import json_decoder, unwrapFirstError from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING: if TYPE_CHECKING:
@ -1008,7 +1007,7 @@ class E2eKeysHandler:
return signature_list, failures return signature_list, failures
async def _get_e2e_cross_signing_verify_key( async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Tuple[JsonDict, str, VerifyKey]: ) -> Tuple[JsonDict, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key. """Fetch locally or remotely query for a cross-signing public key.
@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater:
# user_id -> list of updates waiting to be handled. # user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have about to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
cache_name="signing_key_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
)
async def incoming_signing_key_update( async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict self, origin: str, edu_content: JsonDict
) -> None: ) -> None:

View file

@ -21,7 +21,17 @@ import itertools
import logging import logging
from collections.abc import Container from collections.abc import Container
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import attr import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: async def on_receive_pdu(
self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
) -> None:
"""Process a PDU received via a federation /send/ transaction, or """Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events via backfill of missing prev_events
Args: Args:
origin (str): server which initiated the /send/ transaction. Will origin: server which initiated the /send/ transaction. Will
be used to fetch missing events or state. be used to fetch missing events or state.
pdu (FrozenEvent): received PDU pdu: received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if sent_to_us_directly: True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event. we pulled it as the result of a missing prev_event.
""" """
@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
await self._process_received_pdu(origin, pdu, state=state) await self._process_received_pdu(origin, pdu, state=state)
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): async def _get_missing_events_for_pdu(
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
) -> None:
""" """
Args: Args:
origin (str): Origin of the pdu. Will be called to get the missing events origin: Origin of the pdu. Will be called to get the missing events
pdu: received pdu pdu: received pdu
prevs (set(str)): List of event ids which we are missing prevs: List of event ids which we are missing
min_depth (int): Minimum depth of events to return. min_depth: Minimum depth of events to return.
""" """
room_id = pdu.room_id room_id = pdu.room_id
@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
origin: str, origin: str,
event: EventBase, event: EventBase,
state: Optional[Iterable[EventBase]], state: Optional[Iterable[EventBase]],
): ) -> None:
"""Called when we have a new pdu. We need to do auth checks and put it """Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler. through the StateHandler.
@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
logger.exception("Failed to resync device for %s", sender) logger.exception("Failed to resync device for %s", sender)
@log_function @log_function
async def backfill(self, dest, room_id, limit, extremities): async def backfill(
self, dest: str, room_id: str, limit: int, extremities: List[str]
) -> List[EventBase]:
"""Trigger a backfill request to `dest` for the given `room_id` """Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side This will attempt to get more events from the remote. If the other side
@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
curr_state = await self.state_handler.get_current_state(room_id) curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state): def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
"""Get joined domains from state """Get joined domains from state
Args: Args:
state (dict[tuple, FrozenEvent]): State map from type/state state: State map from type/state key to event.
key to event.
Returns: Returns:
list[tuple[str, int]]: Returns a list of servers with the Returns a list of servers with the lowest depth of their joins.
lowest depth of their joins. Sorted by lowest depth first. Sorted by lowest depth first.
""" """
joined_users = [ joined_users = [
(state_key, int(event.depth)) (state_key, int(event.depth))
@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name domain for domain, depth in curr_domains if domain != self.server_name
] ]
async def try_backfill(domains): async def try_backfill(domains: List[str]) -> bool:
# TODO: Should we try multiple of these at a time? # TODO: Should we try multiple of these at a time?
for dom in domains: for dom in domains:
try: try:
@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
} }
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id]) likely_extremeties_domains = get_domains_from_state(states[e_id])
success = await try_backfill( success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains] [
dom
for dom, _ in likely_extremeties_domains
if dom not in tried_domains
]
) )
if success: if success:
return True return True
tried_domains.update(dom for dom, _ in likely_domains) tried_domains.update(dom for dom, _ in likely_extremeties_domains)
return False return False
async def _get_events_and_persist( async def _get_events_and_persist(
self, destination: str, room_id: str, events: Iterable[str] self, destination: str, room_id: str, events: Iterable[str]
): ) -> None:
"""Fetch the given events from a server, and persist them as outliers. """Fetch the given events from a server, and persist them as outliers.
This function *does not* recursively get missing auth events of the This function *does not* recursively get missing auth events of the
@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
event_infos, event_infos,
) )
def _sanity_check_event(self, ev): def _sanity_check_event(self, ev: EventBase) -> None:
""" """
Do some early sanity checks of a received event Do some early sanity checks of a received event
@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
or cascade of event fetches. or cascade of event fetches.
Args: Args:
ev (synapse.events.EventBase): event to be checked ev: event to be checked
Returns: None
Raises: Raises:
SynapseError if the event does not pass muster SynapseError if the event does not pass muster
@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
) )
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
async def send_invite(self, target_host, event): async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
"""Sends the invite to the remote server for signing. """Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution. Invites must be signed by the invitee's server before distribution.
@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue) run_in_background(self._handle_queued_pdus, room_queue)
async def _handle_queued_pdus(self, room_queue): async def _handle_queued_pdus(
self, room_queue: List[Tuple[EventBase, str]]
) -> None:
"""Process PDUs which got queued up while we were busy send_joining. """Process PDUs which got queued up while we were busy send_joining.
Args: Args:
room_queue (list[FrozenEvent, str]): list of PDUs to be processed room_queue: list of PDUs to be processed and the servers that sent them
and the servers that sent them
""" """
for p, origin in room_queue: for p, origin in room_queue:
try: try:
@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
return event return event
async def on_send_join_request(self, origin, pdu): async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
"""We have received a join event for a room. Fully process it and """We have received a join event for a room. Fully process it and
respond with the current state and auth chains. respond with the current state and auth chains.
""" """
@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
async def on_invite_request( async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion self, origin: str, event: EventBase, room_version: RoomVersion
): ) -> EventBase:
"""We've got an invite event. Process and persist it. Sign it. """We've got an invite event. Process and persist it. Sign it.
Respond with the now signed event. Respond with the now signed event.
@ -1711,7 +1729,7 @@ class FederationHandler(BaseHandler):
member_handler = self.hs.get_room_member_handler() member_handler = self.hs.get_room_member_handler()
# We don't rate limit based on room ID, as that should be done by # We don't rate limit based on room ID, as that should be done by
# sending server. # sending server.
member_handler.ratelimit_invite(None, event.state_key) await member_handler.ratelimit_invite(None, None, event.state_key)
# keep a record of the room version, if we don't yet know it. # keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a # (this may get overwritten if we later get a different room version in a
@ -1772,7 +1790,7 @@ class FederationHandler(BaseHandler):
room_id: str, room_id: str,
user_id: str, user_id: str,
membership: str, membership: str,
content: JsonDict = {}, content: JsonDict,
params: Optional[Dict[str, Union[str, Iterable[str]]]] = None, params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
) -> Tuple[str, EventBase, RoomVersion]: ) -> Tuple[str, EventBase, RoomVersion]:
( (
@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
return event return event
async def on_send_leave_request(self, origin, pdu): async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
""" We have received a leave event for a room. Fully process it.""" """ We have received a leave event for a room. Fully process it."""
event = pdu event = pdu
@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
else: else:
return None return None
async def get_min_depth_for_context(self, context): async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context) return await self.store.get_min_depth(context)
async def _handle_new_event( async def _handle_new_event(
self, origin, event, state=None, auth_events=None, backfilled=False self,
): origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None,
backfilled: bool = False,
) -> EventContext:
context = await self._prep_event( context = await self._prep_event(
origin, event, state=state, auth_events=auth_events, backfilled=backfilled origin, event, state=state, auth_events=auth_events, backfilled=backfilled
) )
@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
logger.warning("Soft-failing %r because %s", event, e) logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True event.internal_metadata.soft_failed = True
async def on_query_auth(
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
):
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
event = await self.store.get_event(event_id, check_room_id=room_id)
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
for e in remote_auth_chain:
try:
await self._handle_new_event(origin, e)
except AuthError:
pass
# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
room_id, list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
logger.debug("on_query_auth returning: %s", ret)
return ret
async def on_get_missing_events( async def on_get_missing_events(
self, origin, room_id, earliest_events, latest_events, limit self,
): origin: str,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
limit: int,
) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin) in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
assumes that we have already processed all events in remote_auth assumes that we have already processed all events in remote_auth
Params: Params:
local_auth (list) local_auth
remote_auth (list) remote_auth
Returns: Returns:
dict dict
@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
@log_function @log_function
async def exchange_third_party_invite( async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
): ) -> None:
third_party_invite = {"signed": signed} third_party_invite = {"signed": signed}
event_dict = { event_dict = {
@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
await member_handler.send_membership_event(None, event, context) await member_handler.send_membership_event(None, event, context)
async def add_display_name_to_third_party_invite( async def add_display_name_to_third_party_invite(
self, room_version, event_dict, event, context self,
): room_version: str,
event_dict: JsonDict,
event: EventBase,
context: EventContext,
) -> Tuple[EventBase, EventContext]:
key = ( key = (
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"], event.content["third_party_invite"]["signed"]["token"],
@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
EventValidator().validate_new(event, self.config) EventValidator().validate_new(event, self.config)
return (event, context) return (event, context)
async def _check_signature(self, event, context): async def _check_signature(self, event: EventBase, context: EventContext) -> None:
""" """
Checks that the signature in the event is consistent with its invite. Checks that the signature in the event is consistent with its invite.
Args: Args:
event (Event): The m.room.member event to check event: The m.room.member event to check
context (EventContext): context:
Raises: Raises:
AuthError: if signature didn't match any keys, or key has been AuthError: if signature didn't match any keys, or key has been
@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
raise last_exception raise last_exception
async def _check_key_revocation(self, public_key, url): async def _check_key_revocation(self, public_key: str, url: str) -> None:
""" """
Checks whether public_key has been revoked. Checks whether public_key has been revoked.
Args: Args:
public_key (str): base-64 encoded public key. public_key: base-64 encoded public key.
url (str): Key revocation URL. url: Key revocation URL.
Raises: Raises:
AuthError: if the key has been revoked. AuthError: if the key has been revoked.

View file

@ -61,17 +61,19 @@ class IdentityHandler(BaseHandler):
# Ratelimiters for `/requestToken` endpoints. # Ratelimiters for `/requestToken` endpoints.
self._3pid_validation_ratelimiter_ip = Ratelimiter( self._3pid_validation_ratelimiter_ip = Ratelimiter(
store=self.store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
) )
self._3pid_validation_ratelimiter_address = Ratelimiter( self._3pid_validation_ratelimiter_address = Ratelimiter(
store=self.store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
) )
def ratelimit_request_token_requests( async def ratelimit_request_token_requests(
self, self,
request: SynapseRequest, request: SynapseRequest,
medium: str, medium: str,
@ -85,8 +87,12 @@ class IdentityHandler(BaseHandler):
address: The actual threepid ID, e.g. the phone number or email address address: The actual threepid ID, e.g. the phone number or email address
""" """
self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) await self._3pid_validation_ratelimiter_ip.ratelimit(
self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) None, (medium, request.getClientIP())
)
await self._3pid_validation_ratelimiter_address.ratelimit(
None, (medium, address)
)
async def threepid_from_creds( async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str] self, id_server: str, creds: Dict[str, str]
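Keying one limiter on `(medium, client_ip)` and the other on `(medium, address)` throttles a burst of /requestToken calls whether it comes from one IP probing many addresses or many IPs hammering one address; both checks must pass. A toy per-key counter showing why the two keys are independent (names invented; the real limits come from `rc_3pid_validation`):

from typing import Dict, Tuple

class PerKeyCounterSketch:
    """Counts actions per key; a stand-in for one of the two limiters above."""

    def __init__(self, burst_count: int):
        self.burst_count = burst_count
        self.counts = {}  # type: Dict[Tuple[str, str], int]

    def ratelimit(self, key: Tuple[str, str]) -> None:
        self.counts[key] = self.counts.get(key, 0) + 1
        if self.counts[key] > self.burst_count:
            raise RuntimeError("rate limit exceeded for %r" % (key,))

by_ip = PerKeyCounterSketch(burst_count=3)
by_address = PerKeyCounterSketch(burst_count=3)

def check_request_token(medium: str, client_ip: str, address: str) -> None:
    by_ip.ratelimit((medium, client_ip))      # throttles a hot client IP
    by_address.ratelimit((medium, address))   # throttles a targeted address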

View file

@ -137,7 +137,7 @@ class MessageHandler:
self, self,
user_id: str, user_id: str,
room_id: str, room_id: str,
state_filter: StateFilter = StateFilter.all(), state_filter: Optional[StateFilter] = None,
at_token: Optional[StreamToken] = None, at_token: Optional[StreamToken] = None,
is_guest: bool = False, is_guest: bool = False,
) -> List[dict]: ) -> List[dict]:
@ -164,6 +164,8 @@ class MessageHandler:
AuthError (403) if the user doesn't have permission to view AuthError (403) if the user doesn't have permission to view
members of this room. members of this room.
""" """
state_filter = state_filter or StateFilter.all()
if at_token: if at_token:
# FIXME this claims to get the state at a stream position, but # FIXME this claims to get the state at a stream position, but
# get_recent_events_for_room operates by topo ordering. This therefore # get_recent_events_for_room operates by topo ordering. This therefore
@ -385,7 +387,7 @@ class EventCreationHandler:
self._events_shard_config = self.config.worker.events_shard_config self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self.room_invite_state_types = self.hs.config.room_invite_state_types self.room_invite_state_types = self.hs.config.api.room_prejoin_state
self.membership_types_to_include_profile_data_in = ( self.membership_types_to_include_profile_data_in = (
{Membership.JOIN, Membership.INVITE} {Membership.JOIN, Membership.INVITE}
@ -876,7 +878,7 @@ class EventCreationHandler:
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
ratelimit: bool = True, ratelimit: bool = True,
extra_users: List[UserID] = [], extra_users: Optional[List[UserID]] = None,
ignore_shadow_ban: bool = False, ignore_shadow_ban: bool = False,
) -> EventBase: ) -> EventBase:
"""Processes a new event. """Processes a new event.
@ -904,6 +906,7 @@ class EventCreationHandler:
Raises: Raises:
ShadowBanError if the requester has been shadow-banned. ShadowBanError if the requester has been shadow-banned.
""" """
extra_users = extra_users or []
# we don't apply shadow-banning to membership events here. Invites are blocked # we don't apply shadow-banning to membership events here. Invites are blocked
# higher up the stack, and we allow shadow-banned users to send join and leave # higher up the stack, and we allow shadow-banned users to send join and leave
@ -1073,7 +1076,7 @@ class EventCreationHandler:
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
ratelimit: bool = True, ratelimit: bool = True,
extra_users: List[UserID] = [], extra_users: Optional[List[UserID]] = None,
) -> EventBase: ) -> EventBase:
"""Called when we have fully built the event, have already """Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth. calculated the push actions for the event, and checked auth.
@ -1085,6 +1088,8 @@ class EventCreationHandler:
it was de-duplicated (e.g. because we had already persisted an it was de-duplicated (e.g. because we had already persisted an
event with the same transaction ID.) event with the same transaction ID.)
""" """
extra_users = extra_users or []
assert self.storage.persistence is not None assert self.storage.persistence is not None
assert self._events_shard_config.should_handle( assert self._events_shard_config.should_handle(
self._instance_name, event.room_id self._instance_name, event.room_id
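The `extra_users: List[UserID] = []` to `Optional[List[UserID]] = None` changes in this file (and the matching `StateFilter.all()` defaults above) sidestep Python's shared-default pitfall: a default value is evaluated once, when the function is defined, not once per call. A self-contained demonstration:

from typing import List, Optional

def broken(extra_users: List[str] = []) -> List[str]:
    # The default list is created once, at definition time, so every call
    # without the argument mutates the same shared object.
    extra_users.append("@alice:example.org")
    return extra_users

def fixed(extra_users: Optional[List[str]] = None) -> List[str]:
    extra_users = extra_users or []  # fresh list per call, as in this diff
    extra_users.append("@alice:example.org")
    return extra_users

assert broken() == ["@alice:example.org"]
assert broken() == ["@alice:example.org"] * 2  # state leaked between calls
assert fixed() == fixed() == ["@alice:example.org"]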

View file

@ -25,7 +25,17 @@ The methods that define policy are:
import abc import abc
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple from typing import (
TYPE_CHECKING,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
from prometheus_client import Counter from prometheus_client import Counter
from typing_extensions import ContextManager from typing_extensions import ContextManager
@ -34,6 +44,7 @@ import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
@ -42,7 +53,7 @@ from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer from synapse.util.wheel_timer import WheelTimer
@ -209,6 +220,7 @@ class PresenceHandler(BasePresenceHandler):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.presence_router = hs.get_presence_router()
self._presence_enabled = hs.config.use_presence self._presence_enabled = hs.config.use_presence
federation_registry = hs.get_federation_registry() federation_registry = hs.get_federation_registry()
@ -653,7 +665,7 @@ class PresenceHandler(BasePresenceHandler):
""" """
stream_id, max_token = await self.store.update_presence(states) stream_id, max_token = await self.store.update_presence(states)
parties = await get_interested_parties(self.store, states) parties = await get_interested_parties(self.store, self.presence_router, states)
room_ids_to_states, users_to_states = parties room_ids_to_states, users_to_states = parties
self.notifier.on_new_event( self.notifier.on_new_event(
@ -1041,7 +1053,12 @@ class PresenceEventSource:
# #
# Presence -> Notifier -> PresenceEventSource -> Presence # Presence -> Notifier -> PresenceEventSource -> Presence
# #
# Same with get_module_api, get_presence_router
#
# AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler
self.get_presence_handler = hs.get_presence_handler self.get_presence_handler = hs.get_presence_handler
self.get_module_api = hs.get_module_api
self.get_presence_router = hs.get_presence_router
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
@ -1054,8 +1071,8 @@ class PresenceEventSource:
room_ids=None, room_ids=None,
include_offline=True, include_offline=True,
explicit_room_id=None, explicit_room_id=None,
**kwargs **kwargs,
): ) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events are: # The process for getting presence events are:
# 1. Get the rooms the user is in. # 1. Get the rooms the user is in.
# 2. Get the list of user in the rooms. # 2. Get the list of user in the rooms.
@ -1068,7 +1085,17 @@ class PresenceEventSource:
# We don't try and limit the presence updates by the current token, as # We don't try and limit the presence updates by the current token, as
# sending down the rare duplicate is not a concern. # sending down the rare duplicate is not a concern.
user_id = user.to_string()
stream_change_cache = self.store.presence_stream_cache
with Measure(self.clock, "presence.get_new_events"): with Measure(self.clock, "presence.get_new_events"):
if user_id in self.get_module_api()._send_full_presence_to_local_users:
# This user has been specified by a module to receive all current, online
# user presence. Removing from_key and setting include_offline to false
# will effectively do this.
from_key = None
include_offline = False
if from_key is not None: if from_key is not None:
from_key = int(from_key) from_key = int(from_key)
@ -1091,59 +1118,209 @@ class PresenceEventSource:
# doesn't return. C.f. #5503. # doesn't return. C.f. #5503.
return [], max_token return [], max_token
presence = self.get_presence_handler() # Figure out which other users this user should receive updates for
stream_change_cache = self.store.presence_stream_cache
users_interested_in = await self._get_interested_in(user, explicit_room_id) users_interested_in = await self._get_interested_in(user, explicit_room_id)
user_ids_changed = set() # type: Collection[str] # We have a set of users that we're interested in the presence of. We want to
changed = None # cross-reference that with the users that have actually changed their presence.
# Check whether this user should see all user updates
if users_interested_in == PresenceRouter.ALL_USERS:
# Provide presence state for all users
presence_updates = await self._filter_all_presence_updates_for_user(
user_id, include_offline, from_key
)
# Remove the user from the list of users to receive all presence
if user_id in self.get_module_api()._send_full_presence_to_local_users:
self.get_module_api()._send_full_presence_to_local_users.remove(
user_id
)
return presence_updates, max_token
# Make mypy happy. users_interested_in should now be a set
assert not isinstance(users_interested_in, str)
# The set of users that we're interested in and that have had a presence update.
# We'll actually pull the presence updates for these users at the end.
interested_and_updated_users = (
set()
) # type: Union[Set[str], FrozenSet[str]]
if from_key: if from_key:
changed = stream_change_cache.get_all_entities_changed(from_key) # First get all users that have had a presence update
updated_users = stream_change_cache.get_all_entities_changed(from_key)
if changed is not None and len(changed) < 500: # Cross-reference users we're interested in with those that have had updates.
assert isinstance(user_ids_changed, set) # Use a slightly-optimised method for processing smaller sets of updates.
if updated_users is not None and len(updated_users) < 500:
# For small deltas, its quicker to get all changes and then # For small deltas, it's quicker to get all changes and then
# work out if we share a room or they're in our presence list # cross-reference with the users we're interested in
get_updates_counter.labels("stream").inc() get_updates_counter.labels("stream").inc()
for other_user_id in changed: for other_user_id in updated_users:
if other_user_id in users_interested_in: if other_user_id in users_interested_in:
user_ids_changed.add(other_user_id) # mypy thinks this variable could be a FrozenSet as it's possibly set
# to one in the `get_entities_changed` call below, and `add()` is not
# a method on a FrozenSet. That doesn't affect us here though, as
# `interested_and_updated_users` is clearly a set() above.
interested_and_updated_users.add(other_user_id) # type: ignore
else: else:
# Too many possible updates. Find all users we can see and check # Too many possible updates. Find all users we can see and check
# if any of them have changed. # if any of them have changed.
get_updates_counter.labels("full").inc() get_updates_counter.labels("full").inc()
if from_key: interested_and_updated_users = (
user_ids_changed = stream_change_cache.get_entities_changed( stream_change_cache.get_entities_changed(
users_interested_in, from_key users_interested_in, from_key
) )
else:
user_ids_changed = users_interested_in
updates = await presence.current_state_for_users(user_ids_changed)
if include_offline:
return (list(updates.values()), max_token)
else:
return (
[s for s in updates.values() if s.state != PresenceState.OFFLINE],
max_token,
) )
else:
# No from_key has been specified. Return the presence for all users
# this user is interested in
interested_and_updated_users = users_interested_in
# Retrieve the current presence state for each user
users_to_state = await self.get_presence_handler().current_state_for_users(
interested_and_updated_users
)
presence_updates = list(users_to_state.values())
# Remove the user from the list of users to receive all presence
if user_id in self.get_module_api()._send_full_presence_to_local_users:
self.get_module_api()._send_full_presence_to_local_users.remove(user_id)
if not include_offline:
# Filter out offline presence states
presence_updates = self._filter_offline_presence_state(presence_updates)
return presence_updates, max_token
async def _filter_all_presence_updates_for_user(
self,
user_id: str,
include_offline: bool,
from_key: Optional[int] = None,
) -> List[UserPresenceState]:
"""
Computes the presence updates a user should receive.
First pulls presence updates from the database. Then consults PresenceRouter
for whether any updates should be excluded by user ID.
Args:
user_id: The User ID of the user to compute presence updates for.
include_offline: Whether to include offline presence states from the results.
from_key: The minimum stream ID of updates to pull from the database
before filtering.
Returns:
A list of presence states for the given user to receive.
"""
if from_key:
# Only return updates since the last sync
updated_users = self.store.presence_stream_cache.get_all_entities_changed(
from_key
)
if not updated_users:
updated_users = []
# Get the actual presence update for each change
users_to_state = await self.get_presence_handler().current_state_for_users(
updated_users
)
presence_updates = list(users_to_state.values())
if not include_offline:
# Filter out offline states
presence_updates = self._filter_offline_presence_state(presence_updates)
else:
users_to_state = await self.store.get_presence_for_all_users(
include_offline=include_offline
)
presence_updates = list(users_to_state.values())
# TODO: This feels wildly inefficient, and it's unfortunate we need to ask the
# module for information on a number of users when we then only take the info
# for a single user
# Filter through the presence router
users_to_state_set = await self.get_presence_router().get_users_for_states(
presence_updates
)
# We only want the mapping for the syncing user
presence_updates = list(users_to_state_set[user_id])
# Return presence information for all users
return presence_updates
def _filter_offline_presence_state(
self, presence_updates: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
"""Given an iterable containing user presence updates, return a list with any offline
presence states removed.
Args:
presence_updates: Presence states to filter
Returns:
A new list with any offline presence states removed.
"""
return [
update
for update in presence_updates
if update.state != PresenceState.OFFLINE
]
def get_current_key(self): def get_current_key(self):
return self.store.get_current_presence_token() return self.store.get_current_presence_token()
@cached(num_args=2, cache_context=True) @cached(num_args=2, cache_context=True)
async def _get_interested_in(self, user, explicit_room_id, cache_context): async def _get_interested_in(
self,
user: UserID,
explicit_room_id: Optional[str] = None,
cache_context: Optional[_CacheContext] = None,
) -> Union[Set[str], str]:
"""Returns the set of users that the given user should see presence """Returns the set of users that the given user should see presence
updates for updates for.
Args:
user: The user to retrieve presence updates for.
explicit_room_id: If provided, the users in this room are also included.
Returns:
A set of user IDs to return presence updates for, or "ALL" to return all
known updates.
""" """
user_id = user.to_string() user_id = user.to_string()
users_interested_in = set() users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence users_interested_in.add(user_id) # So that we receive our own presence
# cache_context isn't likely to ever be None due to the @cached decorator,
# but we can't have a non-optional argument after the optional argument
# explicit_room_id either. Assert cache_context is not None so we can use it
# without mypy complaining.
assert cache_context
# Check with the presence router whether we should poll additional users for
# their presence information
additional_users = await self.get_presence_router().get_interested_users(
user.to_string()
)
if additional_users == PresenceRouter.ALL_USERS:
# If the module requested that this user see the presence updates of *all*
# users, then simply return that instead of calculating what rooms this
# user shares
return PresenceRouter.ALL_USERS
# Add the additional users from the router
users_interested_in.update(additional_users)
# Find the users who share a room with this user
users_who_share_room = await self.store.get_users_who_share_room_with_user( users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id, on_invalidate=cache_context.invalidate user_id, on_invalidate=cache_context.invalidate
) )
@ -1314,14 +1491,15 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
async def get_interested_parties( async def get_interested_parties(
store: DataStore, states: List[UserPresenceState] store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState]
) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]: ) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
"""Given a list of states return which entities (rooms, users) """Given a list of states return which entities (rooms, users)
are interested in the given states. are interested in the given states.
Args: Args:
store store: The homeserver's data store.
states presence_router: A module for augmenting the destinations for presence updates.
states: A list of incoming user presence updates.
Returns: Returns:
A 2-tuple of `(room_ids_to_states, users_to_states)`, A 2-tuple of `(room_ids_to_states, users_to_states)`,
@ -1337,11 +1515,22 @@ async def get_interested_parties(
# Always notify self # Always notify self
users_to_states.setdefault(state.user_id, []).append(state) users_to_states.setdefault(state.user_id, []).append(state)
# Ask a presence routing module for any additional parties if one
# is loaded.
router_users_to_states = await presence_router.get_users_for_states(states)
# Update the dictionaries with additional destinations and state to send
for user_id, user_states in router_users_to_states.items():
users_to_states.setdefault(user_id, []).extend(user_states)
return room_ids_to_states, users_to_states return room_ids_to_states, users_to_states
async def get_interested_remotes( async def get_interested_remotes(
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler store: DataStore,
presence_router: PresenceRouter,
states: List[UserPresenceState],
state_handler: StateHandler,
) -> List[Tuple[Collection[str], List[UserPresenceState]]]: ) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers """Given a list of presence states figure out which remote servers
should be sent which. should be sent which.
@ -1349,9 +1538,10 @@ async def get_interested_remotes(
All the presence states should be for local users only. All the presence states should be for local users only.
Args: Args:
store store: The homeserver's data store.
states presence_router: A module for augmenting the destinations for presence updates.
state_handler states: A list of incoming user presence updates.
state_handler:
Returns: Returns:
A list of 2-tuples of destinations and states, where for A list of 2-tuples of destinations and states, where for
@ -1363,7 +1553,9 @@ async def get_interested_remotes(
# First we look up the rooms each user is in (as well as any explicit # First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote # subscriptions), then for each distinct room we look up the remote
# hosts in those rooms. # hosts in those rooms.
room_ids_to_states, users_to_states = await get_interested_parties(store, states) room_ids_to_states, users_to_states = await get_interested_parties(
store, presence_router, states
)
for room_id, states in room_ids_to_states.items(): for room_id, states in room_ids_to_states.items():
hosts = await state_handler.get_current_hosts_in_room(room_id) hosts = await state_handler.get_current_hosts_in_room(room_id)
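A presence router is any module exposing the two awaitable hooks this file now consults: `get_users_for_states`, which may add extra user -> states destinations when updates are sent out, and `get_interested_users`, which may return a set of user IDs or the ALL_USERS sentinel when a user syncs. A toy module of that shape (the hook names come from this diff; the routing policy and the @monitor user are invented for illustration):

from typing import Any, Dict, Iterable, Set, Union

ALL_USERS = "ALL_USERS"  # stand-in for the PresenceRouter.ALL_USERS sentinel

class ExamplePresenceRouter:
    """Routes a copy of every presence update to a monitoring user."""

    async def get_users_for_states(
        self, state_updates: Iterable[Any]
    ) -> Dict[str, Set[Any]]:
        # Deliver every outgoing update to @monitor as well.
        return {"@monitor:example.org": set(state_updates)}

    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
        # @monitor polls everyone; other users get no extra subscriptions.
        if user_id == "@monitor:example.org":
            return ALL_USERS
        return set()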

View file

@ -174,7 +174,7 @@ class RegistrationHandler(BaseHandler):
user_type: Optional[str] = None, user_type: Optional[str] = None,
default_display_name: Optional[str] = None, default_display_name: Optional[str] = None,
address: Optional[str] = None, address: Optional[str] = None,
bind_emails: Iterable[str] = [], bind_emails: Optional[Iterable[str]] = None,
by_admin: bool = False, by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None, user_agent_ips: Optional[List[Tuple[str, str]]] = None,
auth_provider_id: Optional[str] = None, auth_provider_id: Optional[str] = None,
@ -209,7 +209,9 @@ class RegistrationHandler(BaseHandler):
Raises: Raises:
SynapseError if there was a problem registering. SynapseError if there was a problem registering.
""" """
self.check_registration_ratelimit(address) bind_emails = bind_emails or []
await self.check_registration_ratelimit(address)
result = await self.spam_checker.check_registration_for_spam( result = await self.spam_checker.check_registration_for_spam(
threepid, threepid,
@ -590,7 +592,7 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE, errcode=Codes.EXCLUSIVE,
) )
def check_registration_ratelimit(self, address: Optional[str]) -> None: async def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit """A simple helper method to check whether the registration rate limit has been hit
for a given IP address for a given IP address
@ -604,7 +606,7 @@ class RegistrationHandler(BaseHandler):
if not address: if not address:
return return
self.ratelimiter.ratelimit(address) await self.ratelimiter.ratelimit(None, address)
async def register_with_store( async def register_with_store(
self, self,

View file

@ -20,7 +20,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse import types from synapse import types
from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -29,6 +29,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
@ -75,22 +76,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.allow_per_room_profiles = self.config.allow_per_room_profiles self.allow_per_room_profiles = self.config.allow_per_room_profiles
self._join_rate_limiter_local = Ratelimiter( self._join_rate_limiter_local = Ratelimiter(
store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
) )
self._join_rate_limiter_remote = Ratelimiter( self._join_rate_limiter_remote = Ratelimiter(
store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
) )
self._invites_per_room_limiter = Ratelimiter( self._invites_per_room_limiter = Ratelimiter(
store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
) )
self._invites_per_user_limiter = Ratelimiter( self._invites_per_user_limiter = Ratelimiter(
store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
@ -159,15 +164,76 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
async def forget(self, user: UserID, room_id: str) -> None: async def forget(self, user: UserID, room_id: str) -> None:
raise NotImplementedError() raise NotImplementedError()
def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str): async def ratelimit_invite(
self,
requester: Optional[Requester],
room_id: Optional[str],
invitee_user_id: str,
):
"""Ratelimit invites by room and by target user. """Ratelimit invites by room and by target user.
If room ID is missing then we just rate limit by target user. If room ID is missing then we just rate limit by target user.
""" """
if room_id: if room_id:
self._invites_per_room_limiter.ratelimit(room_id) await self._invites_per_room_limiter.ratelimit(requester, room_id)
self._invites_per_user_limiter.ratelimit(invitee_user_id) await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
async def _can_join_without_invite(
self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
) -> bool:
"""
Check whether a user can join a room without an invite.
When joining a room with restricted join rules (as defined in MSC3083),
the membership of spaces must be checked during join.
Args:
state_ids: The state of the room as it currently is.
room_version: The room version of the room being joined.
user_id: The user joining the room.
Returns:
True if the user can join the room, false otherwise.
"""
# This only applies to room versions which support the new join rule.
if not room_version.msc3083_join_rules:
return True
# If there's no join rule, then it defaults to public (so this doesn't apply).
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
if not join_rules_event_id:
return True
# If the join rule is not restricted, this doesn't apply.
join_rules_event = await self.store.get_event(join_rules_event_id)
if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
return True
# If allowed is of the wrong form, then only allow invited users.
allowed_spaces = join_rules_event.content.get("allow", [])
if not isinstance(allowed_spaces, list):
return False
# Get the list of joined rooms and see if there's an overlap.
joined_rooms = await self.store.get_rooms_for_user(user_id)
# Pull out the other room IDs, invalid data gets filtered.
for space in allowed_spaces:
if not isinstance(space, dict):
continue
space_id = space.get("space")
if not isinstance(space_id, str):
continue
# The user was joined to one of the spaces specified, they can join
# this room!
if space_id in joined_rooms:
return True
# The user was not in any of the required spaces.
return False
async def _local_membership_update( async def _local_membership_update(
self, self,
@ -226,9 +292,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
newly_joined = True newly_joined = True
user_is_invited = False
if prev_member_event_id: if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id) prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN newly_joined = prev_member_event.membership != Membership.JOIN
user_is_invited = prev_member_event.membership == Membership.INVITE
# If the member is not already in the room and is not accepting an invite,
# check if they should be allowed access via membership in a space.
if (
newly_joined
and not user_is_invited
and not await self._can_join_without_invite(
prev_state_ids, event.room_version, user_id
)
):
raise AuthError(
403,
"You do not belong to any of the required spaces to join this room.",
)
# Only rate-limit if the user actually joined the room, otherwise we'll end # Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates. # up blocking profile updates.
@ -237,7 +319,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
( (
allowed, allowed,
time_allowed, time_allowed,
) = self._join_rate_limiter_local.can_requester_do_action(requester) ) = await self._join_rate_limiter_local.can_do_action(requester)
if not allowed: if not allowed:
raise LimitExceededError( raise LimitExceededError(
@ -421,9 +503,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if effective_membership_state == Membership.INVITE: if effective_membership_state == Membership.INVITE:
target_id = target.to_string() target_id = target.to_string()
if ratelimit: if ratelimit:
# Don't ratelimit application services. await self.ratelimit_invite(requester, room_id, target_id)
if not requester.app_service or requester.app_service.is_rate_limited():
self.ratelimit_invite(room_id, target_id)
# block any attempts to invite the server notices mxid # block any attempts to invite the server notices mxid
if target_id == self._server_notices_mxid: if target_id == self._server_notices_mxid:
@ -534,7 +614,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
( (
allowed, allowed,
time_allowed, time_allowed,
) = self._join_rate_limiter_remote.can_requester_do_action( ) = await self._join_rate_limiter_remote.can_do_action(
requester, requester,
) )

View file

@@ -24,6 +24,7 @@ from synapse.api.constants import AccountDataTypes, EventTypes, Membership
 from synapse.api.filtering import FilterCollection
 from synapse.events import EventBase
 from synapse.logging.context import current_context
+from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
@@ -251,13 +252,13 @@ class SyncHandler:
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-        # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
+        # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
         self.lazy_loaded_members_cache = ExpiringCache(
             "lazy_loaded_members_cache",
             self.clock,
             max_len=0,
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
-        )
+        )  # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
 
     async def wait_for_sync_for_user(
         self,
@@ -340,7 +341,14 @@
         full_state: bool = False,
     ) -> SyncResult:
         """Get the sync for client needed to match what the server has now."""
-        return await self.generate_sync_result(sync_config, since_token, full_state)
+        with start_active_span("current_sync_for_user"):
+            log_kv({"since_token": since_token})
+            sync_result = await self.generate_sync_result(
+                sync_config, since_token, full_state
+            )
+
+            set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
+            return sync_result
 
     async def push_rules_for_user(self, user: UserID) -> JsonDict:
         user_id = user.to_string()
@@ -540,7 +548,7 @@
         )
 
     async def get_state_after_event(
-        self, event: EventBase, state_filter: StateFilter = StateFilter.all()
+        self, event: EventBase, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """
         Get the room state after the given event
@@ -550,7 +558,7 @@
             state_filter: The state filter used to fetch state from the database.
         """
         state_ids = await self.state_store.get_state_ids_for_event(
-            event.event_id, state_filter=state_filter
+            event.event_id, state_filter=state_filter or StateFilter.all()
         )
         if event.is_state():
             state_ids = dict(state_ids)
@@ -561,7 +569,7 @@
         self,
         room_id: str,
         stream_position: StreamToken,
-        state_filter: StateFilter = StateFilter.all(),
+        state_filter: Optional[StateFilter] = None,
     ) -> StateMap[str]:
         """Get the room state at a particular stream position
@@ -581,7 +589,7 @@
         if last_events:
             last_event = last_events[-1]
             state = await self.get_state_after_event(
-                last_event, state_filter=state_filter
+                last_event, state_filter=state_filter or StateFilter.all()
             )
 
         else:
@@ -725,8 +733,10 @@
     def get_lazy_loaded_members_cache(
         self, cache_key: Tuple[str, Optional[str]]
-    ) -> LruCache:
-        cache = self.lazy_loaded_members_cache.get(cache_key)
+    ) -> LruCache[str, str]:
+        cache = self.lazy_loaded_members_cache.get(
+            cache_key
+        )  # type: Optional[LruCache[str, str]]
         if cache is None:
             logger.debug("creating LruCache for %r", cache_key)
             cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
@@ -963,6 +973,7 @@
         # to query up to a given point.
         # Always use the `now_token` in `SyncResultBuilder`
         now_token = self.event_sources.get_current_token()
+        log_kv({"now_token": now_token})
 
         logger.debug(
             "Calculating sync response for %r between %s and %s",
@@ -1224,6 +1235,13 @@
                 user_id, device_id, since_stream_id, now_token.to_device_key
             )
 
+            for message in messages:
+                # We pop here as we shouldn't be sending the message ID down
+                # `/sync`
+                message_id = message.pop("message_id", None)
+                if message_id:
+                    set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+
             logger.debug(
                 "Returning %d to-device messages between %d and %d (current token: %d)",
                 len(messages),
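
The `StateFilter = StateFilter.all()` defaults replaced above are instances of a general Python pitfall: a default argument value is evaluated once, at function definition time, so every call shares the same object. `Optional[...] = None` plus `x or default()` gives each call a fresh value. A minimal demonstration of the failure mode with a mutable default:

from typing import Optional


def broken(filter: dict = {}):  # the {} is created once, at definition time
    filter.setdefault("types", []).append("m.room.member")
    return filter


def fixed(filter: Optional[dict] = None):  # a fresh dict per call
    filter = filter or {}
    filter.setdefault("types", []).append("m.room.member")
    return filter


print(broken())  # {'types': ['m.room.member']}
print(broken())  # {'types': ['m.room.member', 'm.room.member']}  <- state leaked
print(fixed())   # {'types': ['m.room.member']}
print(fixed())   # {'types': ['m.room.member']}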

View file

@@ -19,7 +19,10 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.appservice import ApplicationService
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.replication.tcp.streams import TypingStream
 from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -86,6 +89,7 @@ class FollowerTypingHandler:
         self._member_last_federation_poke = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
 
+    @wrap_as_background_process("typing._handle_timeouts")
     def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")

View file

@@ -297,7 +297,7 @@ class SimpleHttpClient:
     def __init__(
         self,
         hs: "HomeServer",
-        treq_args: Dict[str, Any] = {},
+        treq_args: Optional[Dict[str, Any]] = None,
         ip_whitelist: Optional[IPSet] = None,
         ip_blacklist: Optional[IPSet] = None,
         user_agent: Optional[str] = None,
@@ -318,7 +318,7 @@ class SimpleHttpClient:
         self._ip_whitelist = ip_whitelist
         self._ip_blacklist = ip_blacklist
-        self._extra_treq_args = treq_args
+        self._extra_treq_args = treq_args or {}
         self.user_agent = user_agent or hs.version_string
         self.clock = hs.get_clock()
@@ -591,7 +591,7 @@ class SimpleHttpClient:
         uri: str,
         json_body: Any,
         args: Optional[QueryParams] = None,
-        headers: RawHeaders = None,
+        headers: Optional[RawHeaders] = None,
     ) -> Any:
         """Puts some json to the given URI.

View file

@@ -272,7 +272,7 @@ class MatrixFederationHttpClient:
         self,
         request: MatrixFederationRequest,
         try_trailing_slash_on_400: bool = False,
-        **send_request_args
+        **send_request_args,
     ) -> IResponse:
         """Wrapper for _send_request which can optionally retry the request
         upon receiving a combination of a 400 HTTP response code and a
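
Several hunks in this merge add a trailing comma after a `**kwargs` parameter. That is more than cosmetic: a trailing comma after `*args`/`**kwargs` in a function signature was a SyntaxError before Python 3.6, so these hunks quietly signal the minimum interpreter the codebase now targets. A tiny illustration (generic names, not the real method):

# Valid on Python 3.6+; a SyntaxError on 3.5 and earlier:
def send(request, try_trailing_slash_on_400: bool = False, **send_request_args,):
    ...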

View file

@@ -27,7 +27,7 @@ from twisted.python.failure import Failure
 from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
 from twisted.web.error import SchemeNotSupported
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent
+from twisted.web.iweb import IAgent, IPolicyForHTTPS
 
 from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
@@ -88,12 +88,14 @@ class ProxyAgent(_AgentBase):
         self,
         reactor,
         proxy_reactor=None,
-        contextFactory=BrowserLikePolicyForHTTPS(),
+        contextFactory: Optional[IPolicyForHTTPS] = None,
         connectTimeout=None,
         bindAddress=None,
         pool=None,
         use_proxy=False,
     ):
+        contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
+
         _AgentBase.__init__(self, reactor, pool)
 
         if proxy_reactor is None:

View file

@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union
 
 import attr
 from zope.interface import implementer
@@ -26,7 +26,11 @@ from twisted.web.server import Request, Site
 from synapse.config.server import ListenerConfig
 from synapse.http import get_request_user_agent, redact_uri
 from synapse.http.request_metrics import RequestMetrics, requests_counter
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import (
+    ContextRequest,
+    LoggingContext,
+    PreserveLoggingContext,
+)
 from synapse.types import Requester
 
 logger = logging.getLogger(__name__)
@@ -63,7 +67,7 @@ class SynapseRequest(Request):
         # The requester, if authenticated. For federation requests this is the
         # server name, for client requests this is the Requester object.
-        self.requester = None  # type: Optional[Union[Requester, str]]
+        self._requester = None  # type: Optional[Union[Requester, str]]
 
         # we can't yet create the logcontext, as we don't know the method.
         self.logcontext = None  # type: Optional[LoggingContext]
@@ -93,6 +97,31 @@ class SynapseRequest(Request):
             self.site.site_tag,
         )
 
+    @property
+    def requester(self) -> Optional[Union[Requester, str]]:
+        return self._requester
+
+    @requester.setter
+    def requester(self, value: Union[Requester, str]) -> None:
+        # Store the requester, and update some properties based on it.
+
+        # This should only be called once.
+        assert self._requester is None
+
+        self._requester = value
+
+        # A logging context should exist by now (and have a ContextRequest).
+        assert self.logcontext is not None
+        assert self.logcontext.request is not None
+
+        (
+            requester,
+            authenticated_entity,
+        ) = self.get_authenticated_entity()
+        self.logcontext.request.requester = requester
+        # If there's no authenticated entity, it was the requester.
+        self.logcontext.request.authenticated_entity = authenticated_entity or requester
+
     def get_request_id(self):
         return "%s-%i" % (self.get_method(), self.request_seq)
@@ -126,13 +155,60 @@ class SynapseRequest(Request):
             return self.method.decode("ascii")
         return method
 
+    def get_authenticated_entity(self) -> Tuple[Optional[str], Optional[str]]:
+        """
+        Get the "authenticated" entity of the request, which might be the user
+        performing the action, or a user being puppeted by a server admin.
+
+        Returns:
+            A tuple:
+                The first item is a string representing the user making the request.
+
+                The second item is a string or None representing the user who
+                authenticated when making this request. See
+                Requester.authenticated_entity.
+        """
+        # Convert the requester into a string that we can log
+        if isinstance(self._requester, str):
+            return self._requester, None
+        elif isinstance(self._requester, Requester):
+            requester = self._requester.user.to_string()
+            authenticated_entity = self._requester.authenticated_entity
+
+            # If this is a request where the target user doesn't match the user who
+            # authenticated (e.g. an admin is puppetting a user) then we return both.
+            if self._requester.user.to_string() != authenticated_entity:
+                return requester, authenticated_entity
+
+            return requester, None
+        elif self._requester is not None:
+            # This shouldn't happen, but we log it so we don't lose information
+            # and can see that we're doing something wrong.
+            return repr(self._requester), None  # type: ignore[unreachable]
+
+        return None, None
+
     def render(self, resrc):
         # this is called once a Resource has been found to serve the request; in our
         # case the Resource in question will normally be a JsonResource.
 
         # create a LogContext for this request
         request_id = self.get_request_id()
-        self.logcontext = LoggingContext(request_id, request=request_id)
+        self.logcontext = LoggingContext(
+            request_id,
+            request=ContextRequest(
+                request_id=request_id,
+                ip_address=self.getClientIP(),
+                site_tag=self.site.site_tag,
+                # The requester is going to be unknown at this point.
+                requester=None,
+                authenticated_entity=None,
+                method=self.get_method(),
+                url=self.get_redacted_uri(),
+                protocol=self.clientproto.decode("ascii", errors="replace"),
+                user_agent=get_request_user_agent(self),
+            ),
+        )
 
         # override the Server header which is set by twisted
         self.setHeader("Server", self.site.server_version_string)
@@ -277,25 +353,6 @@ class SynapseRequest(Request):
         # to the client (nb may be negative)
         response_send_time = self.finish_time - self._processing_finished_time
 
-        # Convert the requester into a string that we can log
-        authenticated_entity = None
-        if isinstance(self.requester, str):
-            authenticated_entity = self.requester
-        elif isinstance(self.requester, Requester):
-            authenticated_entity = self.requester.authenticated_entity
-
-            # If this is a request where the target user doesn't match the user who
-            # authenticated (e.g. and admin is puppetting a user) then we log both.
-            if self.requester.user.to_string() != authenticated_entity:
-                authenticated_entity = "{},{}".format(
-                    authenticated_entity,
-                    self.requester.user.to_string(),
-                )
-        elif self.requester is not None:
-            # This shouldn't happen, but we log it so we don't lose information
-            # and can see that we're doing something wrong.
-            authenticated_entity = repr(self.requester)  # type: ignore[unreachable]
-
         user_agent = get_request_user_agent(self, "-")
 
         code = str(self.code)
@@ -305,6 +362,13 @@ class SynapseRequest(Request):
             code += "!"
 
         log_level = logging.INFO if self._should_log_request() else logging.DEBUG
+
+        # If this is a request where the target user doesn't match the user who
+        # authenticated (e.g. an admin is puppetting a user) then we log both.
+        requester, authenticated_entity = self.get_authenticated_entity()
+        if authenticated_entity:
+            requester = "{}.{}".format(authenticated_entity, requester)
+
         self.site.access_logger.log(
             log_level,
             "%s - %s - {%s}"
@@ -312,7 +376,7 @@ class SynapseRequest(Request):
             ' %sB %s "%s %s %s" "%s" [%d dbevts]',
             self.getClientIP(),
             self.site.site_tag,
-            authenticated_entity,
+            requester,
             processing_time,
             response_send_time,
             usage.ru_utime,
@@ -433,7 +497,7 @@ class SynapseSite(Site):
             resource,
             server_version_string,
             *args,
-            **kwargs
+            **kwargs,
         ):
             Site.__init__(self, resource, *args, **kwargs)
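
The `requester` property above is a write-once setter with a side effect: the first (and only) assignment mirrors the value into the logging context so every subsequent log line carries it. A compact, self-contained illustration of that idiom (hypothetical names, not Synapse code):

class TrackedRequest:
    """Sketch of a write-once property that mirrors its value elsewhere."""

    def __init__(self):
        self._owner = None
        self.audit_log = []  # stands in for the logging context

    @property
    def owner(self):
        return self._owner

    @owner.setter
    def owner(self, value):
        assert self._owner is None, "owner should only be set once"
        self._owner = value
        # Mirror the value into the audit trail as a side effect.
        self.audit_log.append(("owner", value))


req = TrackedRequest()
req.owner = "@admin:example.com"  # records ("owner", ...) in audit_log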

View file

@@ -22,7 +22,6 @@ them.
 See doc/log_contexts.rst for details on how this works.
 """
-
 import inspect
 import logging
 import threading
@@ -30,6 +29,7 @@ import types
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
 
+import attr
 from typing_extensions import Literal
 
 from twisted.internet import defer, threads
@@ -181,6 +181,29 @@ class ContextResourceUsage:
         return res
 
 
+@attr.s(slots=True)
+class ContextRequest:
+    """
+    A bundle of attributes from the SynapseRequest object.
+
+    This exists to:
+
+    * Avoid a cycle between LoggingContext and SynapseRequest.
+    * Be a single variable that can be passed from parent LoggingContexts to
+      their children.
+    """
+
+    request_id = attr.ib(type=str)
+    ip_address = attr.ib(type=str)
+    site_tag = attr.ib(type=str)
+    requester = attr.ib(type=Optional[str])
+    authenticated_entity = attr.ib(type=Optional[str])
+    method = attr.ib(type=str)
+    url = attr.ib(type=str)
+    protocol = attr.ib(type=str)
+    user_agent = attr.ib(type=str)
+
+
 LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
@@ -256,7 +279,7 @@ class LoggingContext:
         self,
         name: Optional[str] = None,
         parent_context: "Optional[LoggingContext]" = None,
-        request: Optional[str] = None,
+        request: Optional[ContextRequest] = None,
     ) -> None:
         self.previous_context = current_context()
         self.name = name
@@ -281,7 +304,11 @@ class LoggingContext:
         self.parent_context = parent_context
 
         if self.parent_context is not None:
-            self.parent_context.copy_to(self)
+            # we track the current request_id
+            self.request = self.parent_context.request
+
+            # we also track the current scope:
+            self.scope = self.parent_context.scope
 
         if request is not None:
             # the request param overrides the request from the parent context
@@ -289,7 +316,7 @@ class LoggingContext:
     def __str__(self) -> str:
         if self.request:
-            return str(self.request)
+            return self.request.request_id
         return "%s@%x" % (self.name, id(self))
 
     @classmethod
@@ -556,8 +583,23 @@ class LoggingContextFilter(logging.Filter):
         # we end up in a death spiral of infinite loops, so let's check, for
         # robustness' sake.
         if context is not None:
-            # Logging is interested in the request.
-            record.request = context.request  # type: ignore
+            # Logging is interested in the request ID. Note that for backwards
+            # compatibility this is stored as the "request" on the record.
+            record.request = str(context)  # type: ignore
+
+            # Add some data from the HTTP request.
+            request = context.request
+            if request is None:
+                return True
+
+            record.ip_address = request.ip_address  # type: ignore
+            record.site_tag = request.site_tag  # type: ignore
+            record.requester = request.requester  # type: ignore
+            record.authenticated_entity = request.authenticated_entity  # type: ignore
+            record.method = request.method  # type: ignore
+            record.url = request.url  # type: ignore
+            record.protocol = request.protocol  # type: ignore
+            record.user_agent = request.user_agent  # type: ignore
 
         return True
@@ -630,8 +672,8 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
 def nested_logging_context(suffix: str) -> LoggingContext:
     """Creates a new logging context as a child of another.
 
-    The nested logging context will have a 'request' made up of the parent context's
-    request, plus the given suffix.
+    The nested logging context will have a 'name' made up of the parent context's
+    name, plus the given suffix.
 
     CPU/db usage stats will be added to the parent context's on exit.
@@ -641,7 +683,7 @@ def nested_logging_context(suffix: str) -> LoggingContext:
         # ... do stuff
 
     Args:
-        suffix: suffix to add to the parent context's 'request'.
+        suffix: suffix to add to the parent context's 'name'.
 
     Returns:
         LoggingContext: new logging context.
@@ -653,11 +695,17 @@ def nested_logging_context(suffix: str) -> LoggingContext:
         )
         parent_context = None
         prefix = ""
+        request = None
     else:
         assert isinstance(curr_context, LoggingContext)
         parent_context = curr_context
-        prefix = str(parent_context.request)
-    return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
+        prefix = str(parent_context.name)
+        request = parent_context.request
+    return LoggingContext(
+        prefix + "-" + suffix,
+        parent_context=parent_context,
+        request=request,
+    )
 
 
 def preserve_fn(f):
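
The filter changes above enrich every log record with per-request fields (IP, method, URL, requester, and so on) so log formatters can reference them directly. A minimal self-contained version of the same idea, using a plain thread-local in place of Synapse's logcontext machinery (all names here are illustrative):

import logging
import threading

_context = threading.local()  # stands in for the current logging context


class RequestFieldFilter(logging.Filter):
    """Copy per-request metadata from a thread-local onto each log record."""

    def filter(self, record: logging.LogRecord) -> bool:
        request = getattr(_context, "request", None)
        record.request = getattr(request, "request_id", "-") if request else "-"
        record.ip_address = getattr(request, "ip_address", "-") if request else "-"
        return True  # never drop the record, just annotate it


handler = logging.StreamHandler()
handler.addFilter(RequestFieldFilter())
handler.setFormatter(
    logging.Formatter("%(asctime)s %(request)s %(ip_address)s %(message)s")
)
logging.getLogger().addHandler(handler)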

View file

@@ -259,6 +259,14 @@ except ImportError:
 logger = logging.getLogger(__name__)
 
 
+class SynapseTags:
+    # The message ID of any to_device message processed
+    TO_DEVICE_MESSAGE_ID = "to_device.message_id"
+
+    # Whether the sync response has new data to be returned to the client.
+    SYNC_RESULT = "sync.new_data"
+
+
 # Block everything by default
 # A regex which matches the server_names to expose traces for.
 # None means 'block everything'.
@@ -478,7 +486,7 @@ def start_active_span_from_request(
 def start_active_span_from_edu(
     edu_content,
     operation_name,
-    references=[],
+    references: Optional[list] = None,
     tags=None,
     start_time=None,
     ignore_active_span=False,
@@ -493,6 +501,7 @@ def start_active_span_from_edu(
     For the other args see opentracing.tracer
     """
+    references = references or []
 
     if opentracing is None:
         return noop_context_manager()
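
The `noop_context_manager()` fallback exists because tracing is optional: when the opentracing package is absent, call sites still need a do-nothing context manager. A sketch of the pattern using `contextlib.nullcontext` (Python 3.7+; Synapse ships its own `noop_context_manager` helper, presumably to cover older interpreters as well):

import contextlib

try:
    import opentracing
except ImportError:
    opentracing = None


def start_span(name: str):
    """Return a real span if tracing is available, else a no-op context manager."""
    if opentracing is None:
        return contextlib.nullcontext()
    return opentracing.tracer.start_active_span(name)


with start_span("current_sync_for_user"):
    pass  # traced work happens here either way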

View file

@@ -214,7 +214,12 @@ class GaugeBucketCollector:
     Prometheus, and optimise for that case.
     """
 
-    __slots__ = ("_name", "_documentation", "_bucket_bounds", "_metric")
+    __slots__ = (
+        "_name",
+        "_documentation",
+        "_bucket_bounds",
+        "_metric",
+    )
 
     def __init__(
         self,
@@ -242,10 +247,15 @@ class GaugeBucketCollector:
         if self._bucket_bounds[-1] != float("inf"):
             self._bucket_bounds.append(float("inf"))
 
-        self._metric = self._values_to_metric([])
+        # We initially set this to None. We won't report metrics until
+        # this has been initialised after a successful data update
+        self._metric = None  # type: Optional[GaugeHistogramMetricFamily]
+
         registry.register(self)
 
     def collect(self):
-        yield self._metric
+        # Don't report metrics unless we've already collected some data
+        if self._metric is not None:
+            yield self._metric
 
     def update_data(self, values: Iterable[float]):
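
The change above makes the collector silent until it has real data, instead of exporting an all-zero histogram at startup. A standalone sketch of the same shape using prometheus_client's custom-collector API (names and metric here are hypothetical; modeled on GaugeBucketCollector, not a copy of it):

from prometheus_client.core import REGISTRY, GaugeHistogramMetricFamily


class LazyHistogramCollector:
    """Sketch of a collector that stays silent until it has real data."""

    def __init__(self, name: str, documentation: str):
        self._name = name
        self._documentation = documentation
        self._metric = None  # nothing to report yet

    def collect(self):
        # Prometheus calls this on every scrape; yielding nothing means the
        # metric is simply absent rather than reported as a bogus zero.
        if self._metric is not None:
            yield self._metric

    def update_data(self, buckets, gsum):
        # buckets: list of (upper_bound_as_str, cumulative_count) pairs
        self._metric = GaugeHistogramMetricFamily(
            self._name, self._documentation, buckets=buckets, gsum_value=gsum
        )


collector = LazyHistogramCollector("queue_wait_seconds", "Time jobs wait in queue")
REGISTRY.register(collector)
collector.update_data([("1.0", 3), ("+Inf", 5)], gsum=7.5)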

View file

@@ -16,7 +16,7 @@
 import logging
 import threading
 from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Set
+from typing import TYPE_CHECKING, Dict, Optional, Set, Union
 
 from prometheus_client.core import REGISTRY, Counter, Gauge
@@ -199,11 +199,11 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs):
         _background_process_start_count.labels(desc).inc()
         _background_process_in_flight_count.labels(desc).inc()
 
-        with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
+        with BackgroundProcessLoggingContext(desc, count) as context:
             try:
                 ctx = noop_context_manager()
                 if bg_start_span:
-                    ctx = start_active_span(desc, tags={"request_id": context.request})
+                    ctx = start_active_span(desc, tags={"request_id": str(context)})
                 with ctx:
                     return await maybe_awaitable(func(*args, **kwargs))
             except Exception:
@@ -242,13 +242,19 @@ class BackgroundProcessLoggingContext(LoggingContext):
     processes.
     """
 
-    __slots__ = ["_proc"]
+    __slots__ = ["_id", "_proc"]
 
-    def __init__(self, name: str, request: Optional[str] = None):
-        super().__init__(name, request=request)
+    def __init__(self, name: str, id: Optional[Union[int, str]] = None):
+        super().__init__(name)
+        self._id = id
 
         self._proc = _BackgroundProcess(name, self)
 
+    def __str__(self) -> str:
+        if self._id is not None:
+            return "%s-%s" % (self.name, self._id)
+        return "%s@%x" % (self.name, id(self))
+
     def start(self, rusage: "Optional[resource._RUsage]"):
         """Log context has started running (again)."""

View file

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple
 
 from twisted.internet import defer
@@ -50,11 +50,20 @@ class ModuleApi:
         self._auth = hs.get_auth()
         self._auth_handler = auth_handler
         self._server_name = hs.hostname
+        self._presence_stream = hs.get_event_sources().sources["presence"]
 
         # We expose these as properties below in order to attach a helpful docstring.
         self._http_client = hs.get_simple_http_client()  # type: SimpleHttpClient
         self._public_room_list_manager = PublicRoomListManager(hs)
 
+        # The next time these users sync, they will receive the current presence
+        # state of all local users. Users are added by send_local_online_presence_to,
+        # and removed after a successful sync.
+        #
+        # We make this a private variable to deter modules from accessing it directly,
+        # though other classes in Synapse will still do so.
+        self._send_full_presence_to_local_users = set()
+
     @property
     def http_client(self):
         """Allows making outbound HTTP requests to remote resources.
@@ -118,7 +127,7 @@ class ModuleApi:
         return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
 
     @defer.inlineCallbacks
-    def register(self, localpart, displayname=None, emails=[]):
+    def register(self, localpart, displayname=None, emails: Optional[List[str]] = None):
         """Registers a new user with given localpart and optional displayname, emails.
 
         Also returns an access token for the new user.
@@ -138,11 +147,13 @@ class ModuleApi:
         logger.warning(
             "Using deprecated ModuleApi.register which creates a dummy user device."
         )
-        user_id = yield self.register_user(localpart, displayname, emails)
+        user_id = yield self.register_user(localpart, displayname, emails or [])
         _, access_token = yield self.register_device(user_id)
         return user_id, access_token
 
-    def register_user(self, localpart, displayname=None, emails=[]):
+    def register_user(
+        self, localpart, displayname=None, emails: Optional[List[str]] = None
+    ):
         """Registers a new user with given localpart and optional displayname, emails.
 
         Args:
@@ -161,7 +172,7 @@ class ModuleApi:
             self._hs.get_registration_handler().register_user(
                 localpart=localpart,
                 default_display_name=displayname,
-                bind_emails=emails,
+                bind_emails=emails or [],
             )
         )
@@ -385,6 +396,47 @@ class ModuleApi:
         return event
 
+    async def send_local_online_presence_to(self, users: Iterable[str]) -> None:
+        """
+        Forces the equivalent of a presence initial_sync for a set of local or remote
+        users. The users will receive presence for all currently online users that they
+        are considered interested in.
+
+        Updates to remote users will be sent immediately, whereas local users will receive
+        them on their next sync attempt.
+
+        Note that this method can only be run on the main or federation_sender worker
+        processes.
+        """
+        if not self._hs.should_send_federation():
+            raise Exception(
+                "send_local_online_presence_to can only be run "
+                "on processes that send federation",
+            )
+
+        for user in users:
+            if self._hs.is_mine_id(user):
+                # SyncHandler._generate_sync_entry_for_presence calls
+                # presence_source.get_new_events with an empty `from_key` for any
+                # user whose ID is in this set, so that user receives all presence
+                # state on their next incremental sync.
+                # Force a presence initial_sync for this user next time.
+                self._send_full_presence_to_local_users.add(user)
+            else:
+                # Retrieve presence state for currently online users that this user
+                # is considered interested in
+                presence_events, _ = await self._presence_stream.get_new_events(
+                    UserID.from_string(user), from_key=None, include_offline=False
+                )
+
+                # Send to remote destinations
+                await make_deferred_yieldable(
+                    # We pull the federation sender here as we can only do so on workers
+                    # that support sending presence
+                    self._hs.get_federation_sender().send_presence(presence_events)
+                )
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room

View file

@@ -39,6 +39,7 @@ from synapse.api.errors import AuthError
 from synapse.events import EventBase
 from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.opentracing import log_kv, start_active_span
 from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.streams.config import PaginationConfig
@@ -136,6 +137,15 @@ class _NotifierUserStream:
         self.last_notified_ms = time_now_ms
         noify_deferred = self.notify_deferred
 
+        log_kv(
+            {
+                "notify": self.user_id,
+                "stream": stream_key,
+                "stream_id": stream_id,
+                "listeners": self.count_listeners(),
+            }
+        )
+
         users_woken_by_stream_counter.labels(stream_key).inc()
 
         with PreserveLoggingContext():
@@ -266,7 +276,7 @@ class Notifier:
         event: EventBase,
         event_pos: PersistedEventPosition,
         max_room_stream_token: RoomStreamToken,
-        extra_users: Collection[UserID] = [],
+        extra_users: Optional[Collection[UserID]] = None,
     ):
         """Unwraps event and calls `on_new_room_event_args`."""
         self.on_new_room_event_args(
@@ -276,7 +286,7 @@ class Notifier:
             state_key=event.get("state_key"),
             membership=event.content.get("membership"),
             max_room_stream_token=max_room_stream_token,
-            extra_users=extra_users,
+            extra_users=extra_users or [],
         )
 
     def on_new_room_event_args(
@@ -287,7 +297,7 @@ class Notifier:
         membership: Optional[str],
         event_pos: PersistedEventPosition,
         max_room_stream_token: RoomStreamToken,
-        extra_users: Collection[UserID] = [],
+        extra_users: Optional[Collection[UserID]] = None,
     ):
         """Used by handlers to inform the notifier something has happened
         in the room, room event wise.
@@ -303,7 +313,7 @@ class Notifier:
         self.pending_new_room_events.append(
             _PendingRoomEventEntry(
                 event_pos=event_pos,
-                extra_users=extra_users,
+                extra_users=extra_users or [],
                 room_id=room_id,
                 type=event_type,
                 state_key=state_key,
@@ -372,14 +382,14 @@ class Notifier:
         self,
         stream_key: str,
         new_token: Union[int, RoomStreamToken],
-        users: Collection[Union[str, UserID]] = [],
+        users: Optional[Collection[Union[str, UserID]]] = None,
     ):
         try:
             stream_token = None
             if isinstance(new_token, int):
                 stream_token = new_token
             self.appservice_handler.notify_interested_services_ephemeral(
-                stream_key, stream_token, users
+                stream_key, stream_token, users or []
             )
         except Exception:
             logger.exception("Error notifying application services of event")
@@ -394,16 +404,26 @@ class Notifier:
         self,
         stream_key: str,
         new_token: Union[int, RoomStreamToken],
-        users: Collection[Union[str, UserID]] = [],
-        rooms: Collection[str] = [],
+        users: Optional[Collection[Union[str, UserID]]] = None,
+        rooms: Optional[Collection[str]] = None,
     ):
         """Used to inform listeners that something has happened event wise.
 
         Will wake up all listeners for the given users and rooms.
         """
+        users = users or []
+        rooms = rooms or []
+
         with Measure(self.clock, "on_new_event"):
             user_streams = set()
 
+            log_kv(
+                {
+                    "waking_up_explicit_users": len(users),
+                    "waking_up_explicit_rooms": len(rooms),
+                }
+            )
+
             for user in users:
                 user_stream = self.user_to_user_stream.get(str(user))
                 if user_stream is not None:
@@ -476,12 +496,34 @@ class Notifier:
                         (end_time - now) / 1000.0,
                         self.hs.get_reactor(),
                     )
-                    with PreserveLoggingContext():
-                        await listener.deferred
+
+                    with start_active_span("wait_for_events.deferred"):
+                        log_kv(
+                            {
+                                "wait_for_events": "sleep",
+                                "token": prev_token,
+                            }
+                        )
+
+                        with PreserveLoggingContext():
+                            await listener.deferred
+
+                        log_kv(
+                            {
+                                "wait_for_events": "woken",
+                                "token": user_stream.current_token,
+                            }
+                        )
 
                     current_token = user_stream.current_token
 
                     result = await callback(prev_token, current_token)
+                    log_kv(
+                        {
+                            "wait_for_events": "result",
+                            "result": bool(result),
+                        }
+                    )
                     if result:
                         break
@@ -489,8 +531,10 @@ class Notifier:
                     # has happened between the old prev_token and the current_token
                     prev_token = current_token
                 except defer.TimeoutError:
+                    log_kv({"wait_for_events": "timeout"})
                     break
                 except defer.CancelledError:
+                    log_kv({"wait_for_events": "cancelled"})
                     break
 
             if result is None:
@@ -507,7 +551,7 @@ class Notifier:
         pagination_config: PaginationConfig,
         timeout: int,
         is_guest: bool = False,
-        explicit_room_id: str = None,
+        explicit_room_id: Optional[str] = None,
     ) -> EventStreamResult:
         """For the given user and rooms, return any new events for them. If
         there are no new events wait for up to `timeout` milliseconds for any
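
The instrumented wait loop above logs each state transition (sleep, woken, result, timeout, cancelled), which makes a stuck long-poll diagnosable from its trace. A minimal asyncio analogue of the same pattern, with plain logging standing in for log_kv; a sketch only, and note the timeout here is per-iteration rather than the overall deadline the notifier computes:

import asyncio
import logging

logger = logging.getLogger(__name__)


async def wait_for_result(wakeup: asyncio.Event, get_result, timeout: float):
    """Loop until get_result() returns something truthy, a timeout, or cancellation."""
    while True:
        logger.debug("wait_for_events: sleep")
        try:
            await asyncio.wait_for(wakeup.wait(), timeout)
        except asyncio.TimeoutError:
            logger.debug("wait_for_events: timeout")
            return None
        except asyncio.CancelledError:
            logger.debug("wait_for_events: cancelled")
            raise
        wakeup.clear()
        logger.debug("wait_for_events: woken")
        result = get_result()
        logger.debug("wait_for_events: result=%s", bool(result))
        if result:
            return result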

View file

@@ -77,7 +77,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
     async def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
-        self.registration_handler.check_registration_ratelimit(content["address"])
+        await self.registration_handler.check_registration_ratelimit(content["address"])
 
         await self.registration_handler.register_with_store(
             user_id=user_id,

View file

@@ -184,8 +184,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
-        ctx_name = "replication-conn-%s" % self.conn_id
-        self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
+        self._logging_context = BackgroundProcessLoggingContext(
+            "replication-conn", self.conn_id
+        )
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())

View file

@@ -60,7 +60,7 @@ class ConstantProperty(Generic[T, V]):
     constant = attr.ib()  # type: V
 
-    def __get__(self, obj: Optional[T], objtype: Type[T] = None) -> V:
+    def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
         return self.constant
 
     def __set__(self, obj: Optional[T], value: V):
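
`ConstantProperty` is a data descriptor: `__get__` receives the owning instance (or None for class-level access) and the owning class, and `__set__` intercepts assignment. A standalone sketch of the protocol (illustrative names, not Synapse's class):

from typing import Optional


class Constant:
    """A data descriptor that always returns the same value (sketch)."""

    def __init__(self, value):
        self.value = value

    def __get__(self, obj, objtype: Optional[type] = None):
        # obj is None when accessed on the class itself, e.g. Stream.NAME
        return self.value

    def __set__(self, obj, value):
        raise AttributeError("constant attribute is read-only")


class Stream:
    NAME = Constant("typing")


print(Stream.NAME)    # "typing" (class access: obj is None)
print(Stream().NAME)  # "typing" (instance access)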

View file

@@ -54,6 +54,7 @@ from synapse.rest.admin.users import (
     AccountValidityRenewServlet,
     DeactivateAccountRestServlet,
     PushersRestServlet,
+    RateLimitRestServlet,
     ResetPasswordRestServlet,
     SearchUsersRestServlet,
     ShadowBanRestServlet,
@@ -62,7 +63,6 @@ from synapse.rest.admin.users import (
     UserMembershipRestServlet,
     UserRegisterServlet,
     UserRestServletV2,
-    UsersRestServlet,
     UsersRestServletV2,
     UserTokenRestServlet,
     WhoisRestServlet,
@@ -240,6 +240,7 @@ def register_servlets(hs, http_server):
     ShadowBanRestServlet(hs).register(http_server)
     ForwardExtremitiesRestServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
+    RateLimitRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(hs, http_server):
@@ -248,7 +249,6 @@ def register_servlets_for_client_rest_resource(hs, http_server):
     PurgeHistoryStatusRestServlet(hs).register(http_server)
     DeactivateAccountRestServlet(hs).register(http_server)
     PurgeHistoryRestServlet(hs).register(http_server)
-    UsersRestServlet(hs).register(http_server)
     ResetPasswordRestServlet(hs).register(http_server)
     SearchUsersRestServlet(hs).register(http_server)
     ShutdownRoomRestServlet(hs).register(http_server)

View file

@@ -36,6 +36,7 @@ from synapse.rest.admin._base import (
 )
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.storage.databases.main.media_repository import MediaSortOrder
+from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
@@ -44,29 +45,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class UsersRestServlet(RestServlet):
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
-
-    def __init__(self, hs: "HomeServer"):
-        self.hs = hs
-        self.store = hs.get_datastore()
-        self.auth = hs.get_auth()
-        self.admin_handler = hs.get_admin_handler()
-
-    async def on_GET(
-        self, request: SynapseRequest, user_id: str
-    ) -> Tuple[int, List[JsonDict]]:
-        target_user = UserID.from_string(user_id)
-        await assert_requester_is_admin(self.auth, request)
-
-        if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only users a local user")
-
-        ret = await self.store.get_users()
-
-        return 200, ret
-
-
 class UsersRestServletV2(RestServlet):
     PATTERNS = admin_patterns("/users$", "v2")
@@ -117,8 +95,26 @@ class UsersRestServletV2(RestServlet):
         guests = parse_boolean(request, "guests", default=True)
         deactivated = parse_boolean(request, "deactivated", default=False)
 
+        order_by = parse_string(
+            request,
+            "order_by",
+            default=UserSortOrder.NAME.value,
+            allowed_values=(
+                UserSortOrder.NAME.value,
+                UserSortOrder.DISPLAYNAME.value,
+                UserSortOrder.GUEST.value,
+                UserSortOrder.ADMIN.value,
+                UserSortOrder.DEACTIVATED.value,
+                UserSortOrder.USER_TYPE.value,
+                UserSortOrder.AVATAR_URL.value,
+                UserSortOrder.SHADOW_BANNED.value,
+            ),
+        )
+
+        direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+
         users, total = await self.store.get_users_paginate(
-            start, limit, user_id, name, guests, deactivated
+            start, limit, user_id, name, guests, deactivated, order_by, direction
         )
         ret = {"users": users, "total": total}
         if (start + limit) < total:
@@ -985,3 +981,114 @@ class ShadowBanRestServlet(RestServlet):
         await self.store.set_shadow_banned(UserID.from_string(user_id), True)
 
         return 200, {}
+
+
+class RateLimitRestServlet(RestServlet):
+    """An admin API to override ratelimiting for a user.
+
+    Example:
+        POST /_synapse/admin/v1/users/@test:example.com/override_ratelimit
+        {
+          "messages_per_second": 0,
+          "burst_count": 0
+        }
+        200 OK
+        {
+          "messages_per_second": 0,
+          "burst_count": 0
+        }
+    """
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Can only look up local users")
+
+        if not await self.store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        ratelimit = await self.store.get_ratelimit_for_user(user_id)
+
+        if ratelimit:
+            # convert `null` to `0` for consistency;
+            # both values do the same in the ratelimit handler
+            ret = {
+                "messages_per_second": 0
+                if ratelimit.messages_per_second is None
+                else ratelimit.messages_per_second,
+                "burst_count": 0
+                if ratelimit.burst_count is None
+                else ratelimit.burst_count,
+            }
+        else:
+            ret = {}
+
+        return 200, ret
+
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Only local users can be ratelimited")
+
+        if not await self.store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        body = parse_json_object_from_request(request, allow_empty_body=True)
+
+        messages_per_second = body.get("messages_per_second", 0)
+        burst_count = body.get("burst_count", 0)
+
+        if not isinstance(messages_per_second, int) or messages_per_second < 0:
+            raise SynapseError(
+                400,
+                "%r parameter must be a positive int" % (messages_per_second,),
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if not isinstance(burst_count, int) or burst_count < 0:
+            raise SynapseError(
+                400,
+                "%r parameter must be a positive int" % (burst_count,),
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        await self.store.set_ratelimit_for_user(
+            user_id, messages_per_second, burst_count
+        )
+
+        ratelimit = await self.store.get_ratelimit_for_user(user_id)
+        assert ratelimit is not None
+
+        ret = {
+            "messages_per_second": ratelimit.messages_per_second,
+            "burst_count": ratelimit.burst_count,
+        }
+
+        return 200, ret
+
+    async def on_DELETE(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Only local users can be ratelimited")
+
+        if not await self.store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        await self.store.delete_ratelimit_for_user(user_id)
+
+        return 200, {}
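
A hedged usage sketch for the three verbs of the new admin endpoint, assuming a homeserver at matrix.example.com and an admin user's access token (both hypothetical). Setting both values to 0 disables ratelimiting for the user; deleting the override restores the server-wide defaults:

import requests

BASE = "https://matrix.example.com/_synapse/admin/v1"
HEADERS = {"Authorization": "Bearer <admin_access_token>"}
USER = "@test:example.com"

# Exempt the user from message ratelimits entirely (0/0 disables limiting).
requests.post(
    f"{BASE}/users/{USER}/override_ratelimit",
    headers=HEADERS,
    json={"messages_per_second": 0, "burst_count": 0},
)

# Inspect the current override (an empty object means no override is set).
print(requests.get(f"{BASE}/users/{USER}/override_ratelimit", headers=HEADERS).json())

# Remove the override, restoring the server-wide defaults.
requests.delete(f"{BASE}/users/{USER}/override_ratelimit", headers=HEADERS)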

View file

@@ -74,11 +74,13 @@ class LoginRestServlet(RestServlet):
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
+            store=hs.get_datastore(),
             clock=hs.get_clock(),
             rate_hz=self.hs.config.rc_login_address.per_second,
             burst_count=self.hs.config.rc_login_address.burst_count,
         )
         self._account_ratelimiter = Ratelimiter(
+            store=hs.get_datastore(),
             clock=hs.get_clock(),
             rate_hz=self.hs.config.rc_login_account.per_second,
             burst_count=self.hs.config.rc_login_account.burst_count,
@@ -141,20 +143,22 @@ class LoginRestServlet(RestServlet):
                 appservice = self.auth.get_appservice_by_req(request)
 
                 if appservice.is_rate_limited():
-                    self._address_ratelimiter.ratelimit(request.getClientIP())
+                    await self._address_ratelimiter.ratelimit(
+                        None, request.getClientIP()
+                    )
 
                 result = await self._do_appservice_login(login_submission, appservice)
             elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
-                self._address_ratelimiter.ratelimit(request.getClientIP())
+                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
                 result = await self._do_jwt_login(login_submission)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
-                self._address_ratelimiter.ratelimit(request.getClientIP())
+                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
                 result = await self._do_token_login(login_submission)
             else:
-                self._address_ratelimiter.ratelimit(request.getClientIP())
+                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
                 result = await self._do_other_login(login_submission)
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
@@ -258,7 +262,7 @@ class LoginRestServlet(RestServlet):
         # too often. This happens here rather than before as we don't
         # necessarily know the user before now.
         if ratelimit:
-            self._account_ratelimiter.ratelimit(user_id.lower())
+            await self._account_ratelimiter.ratelimit(None, user_id.lower())
 
         if create_non_existent_users:
             canonical_uid = await self.auth_handler.check_user_exists(user_id)

View file

@@ -103,7 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
 
-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )
 
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
@@ -387,7 +389,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )
 
         if next_link:
             # Raise if the provided next_link value isn't valid
@@ -468,7 +472,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(
+        await self.identity_handler.ratelimit_request_token_requests(
             request, "msisdn", msisdn
         )

View file

@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import hmac
 import logging
 import random
@@ -22,7 +21,7 @@ from typing import List, Union
 import synapse
 import synapse.api.auth
 import synapse.types
-from synapse.api.constants import LoginType
+from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
 from synapse.api.errors import (
     Codes,
     InteractiveAuthIncompleteError,
@@ -126,7 +125,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )
 
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email
@@ -208,7 +209,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(
+        await self.identity_handler.ratelimit_request_token_requests(
             request, "msisdn", msisdn
         )
@@ -406,7 +407,7 @@ class RegisterRestServlet(RestServlet):
         client_addr = request.getClientIP()
 
-        self.ratelimiter.ratelimit(client_addr, update=False)
+        await self.ratelimiter.ratelimit(None, client_addr, update=False)
 
         kind = b"user"
         if b"kind" in request.args:
@@ -428,15 +429,20 @@ class RegisterRestServlet(RestServlet):
             raise SynapseError(400, "Invalid username")
         desired_username = body["username"]
 
-        appservice = None
-        if self.auth.has_access_token(request):
-            appservice = self.auth.get_appservice_by_req(request)
-
         # fork off as soon as possible for ASes which have completely
         # different registration flows to normal users
 
         # == Application Service Registration ==
-        if appservice:
+        if body.get("type") == APP_SERVICE_REGISTRATION_TYPE:
+            if not self.auth.has_access_token(request):
+                raise SynapseError(
+                    400,
+                    "Appservice token must be provided when using a type of m.login.application_service",
+                )
+
+            # Verify the AS
+            self.auth.get_appservice_by_req(request)
+
             # Set the desired user according to the AS API (which uses the
             # 'user' key not 'username'). Since this is a new addition, we'll
             # fallback to 'username' if they gave one.
@@ -457,6 +463,11 @@ class RegisterRestServlet(RestServlet):
             )
 
             return 200, result
+        elif self.auth.has_access_token(request):
+            raise SynapseError(
+                400,
+                "An access token should not be provided on requests to /register (except if type is m.login.application_service)",
+            )
 
         # == Normal User Registration == (everyone else)
         if not self._registration_enabled:
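
After this change, appservice registration must be declared explicitly via the body's `type`, and an access token is only acceptable in that case. A hedged sketch of the two request shapes (endpoints, usernames, and tokens are hypothetical, and an ordinary registration will still go through the usual interactive-auth flow):

import requests

BASE = "https://matrix.example.com/_matrix/client/r0"

# Appservice registration: declare the type and present the AS token.
requests.post(
    f"{BASE}/register",
    headers={"Authorization": "Bearer <as_token>"},
    json={"type": "m.login.application_service", "username": "_bridge_alice"},
)

# Ordinary registration: attaching an access token is now a 400 error.
requests.post(
    f"{BASE}/register",
    json={"username": "alice", "password": "correct horse battery staple"},
)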

View file

@@ -164,7 +164,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             clock=self.clock,
             # don't spider URLs more often than once an hour
             expiry_ms=ONE_HOUR,
-        )
+        )  # type: ExpiringCache[str, ObservableDeferred]
 
         if self._worker_run_media_background_jobs:
             self._cleaner_loop = self.clock.looping_call(

View file

@@ -51,6 +51,7 @@ from synapse.crypto import context_factory
 from synapse.crypto.context_factory import RegularPolicyForHTTPS
 from synapse.crypto.keyring import Keyring
 from synapse.events.builder import EventBuilderFactory
+from synapse.events.presence_router import PresenceRouter
 from synapse.events.spamcheck import SpamChecker
 from synapse.events.third_party_rules import ThirdPartyEventRules
 from synapse.events.utils import EventClientSerializer
@@ -329,6 +330,7 @@ class HomeServer(metaclass=abc.ABCMeta):
     @cache_in_self
     def get_registration_ratelimiter(self) -> Ratelimiter:
         return Ratelimiter(
+            store=self.get_datastore(),
             clock=self.get_clock(),
             rate_hz=self.config.rc_registration.per_second,
             burst_count=self.config.rc_registration.burst_count,
@@ -424,6 +426,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         else:
             raise Exception("Workers cannot write typing")
+    @cache_in_self
+    def get_presence_router(self) -> PresenceRouter:
+        return PresenceRouter(self)
+
     @cache_in_self
     def get_typing_handler(self) -> FollowerTypingHandler:
         if self.config.worker.writers.typing == self.get_instance_name():
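The new `get_presence_router` accessor follows the same `@cache_in_self` pattern as the other getters, so the `PresenceRouter` is constructed once per homeserver and then reused. A rough sketch of what such a decorator does (an illustration only, not Synapse's exact implementation, which also guards against re-entrant construction):

```python
import functools

def cache_in_self(builder):
    """Memoise a zero-argument builder method on the instance."""
    attr_name = "_cached_" + builder.__name__

    @functools.wraps(builder)
    def wrapper(self):
        # Build on first access, then return the stored instance thereafter.
        if not hasattr(self, attr_name):
            setattr(self, attr_name, builder(self))
        return getattr(self, attr_name)

    return wrapper
```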

View file

@@ -22,6 +22,7 @@ from typing import (
     Callable,
     DefaultDict,
     Dict,
+    FrozenSet,
     Iterable,
     List,
     Optional,
@@ -515,7 +516,7 @@ class StateResolutionHandler:
             expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
             iterable=True,
             reset_expiry_on_get=True,
-        )
+        )  # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]
         #
         # stuff for tracking time spent on state-res by room
@@ -536,7 +537,7 @@ class StateResolutionHandler:
         state_groups_ids: Dict[int, StateMap[str]],
         event_map: Optional[Dict[str, EventBase]],
         state_res_store: "StateResolutionStore",
-    ):
+    ) -> _StateCacheEntry:
         """Resolves conflicts between a set of state groups
         Always generates a new state group (unless we hit the cache), so should

View file

@@ -488,7 +488,7 @@ class DatabasePool:
         exception_callbacks: List[_CallbackListEntry],
         func: "Callable[..., R]",
         *args: Any,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> R:
         """Start a new database transaction with the given connection.
@@ -622,7 +622,7 @@ class DatabasePool:
         func: "Callable[..., R]",
         *args: Any,
         db_autocommit: bool = False,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> R:
         """Starts a transaction on the database and runs a given function
@@ -682,7 +682,7 @@ class DatabasePool:
         func: "Callable[..., R]",
         *args: Any,
         db_autocommit: bool = False,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> R:
         """Wraps the .runWithConnection() method on the underlying db_pool.
@@ -775,7 +775,7 @@ class DatabasePool:
         desc: str,
         decoder: Optional[Callable[[Cursor], R]],
         query: str,
-        *args: Any
+        *args: Any,
     ) -> R:
         """Runs a single query for a result set.
@@ -900,7 +900,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         desc: str = "simple_upsert",
         lock: bool = True,
     ) -> Optional[bool]:
@@ -927,6 +927,8 @@ class DatabasePool:
             Native upserts always return None. Emulated upserts return True if a
             new entry was created, False if an existing one was updated.
         """
+        insertion_values = insertion_values or {}
+
         attempts = 0
         while True:
             try:
@@ -964,7 +966,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         lock: bool = True,
     ) -> Optional[bool]:
         """
@@ -982,6 +984,8 @@ class DatabasePool:
             Native upserts always return None. Emulated upserts return True if a
             new entry was created, False if an existing one was updated.
         """
+        insertion_values = insertion_values or {}
+
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
@@ -1003,7 +1007,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         lock: bool = True,
     ) -> bool:
         """
@@ -1017,6 +1021,8 @@ class DatabasePool:
         Returns True if a new entry was created, False if an existing
         one was updated.
         """
+        insertion_values = insertion_values or {}
+
         # We need to lock the table :(, unless we're *really* careful
         if lock:
             self.engine.lock_table(txn, table)
@@ -1077,7 +1083,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Use the native UPSERT functionality in recent PostgreSQL versions.
@@ -1090,7 +1096,7 @@ class DatabasePool:
         """
         allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
-        allvalues.update(insertion_values)
+        allvalues.update(insertion_values or {})
         if not values:
             latter = "NOTHING"
@@ -1513,7 +1519,7 @@ class DatabasePool:
         column: str,
         iterable: Iterable[Any],
         retcols: Iterable[str],
-        keyvalues: Dict[str, Any] = {},
+        keyvalues: Optional[Dict[str, Any]] = None,
         desc: str = "simple_select_many_batch",
         batch_size: int = 100,
     ) -> List[Any]:
@@ -1531,6 +1537,8 @@ class DatabasePool:
             desc: description of the transaction, for logging and metrics
             batch_size: the number of rows for each select query
         """
+        keyvalues = keyvalues or {}
+
         results = []  # type: List[Dict[str, Any]]
         if not iterable:
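The repeated `insertion_values: Dict[str, Any] = {}` to `Optional[...] = None` rewrites in this file all address the same Python pitfall: default values are evaluated once at function definition time, so a mutable default is silently shared between calls. A standalone illustration:

```python
def broken(key, acc={}):   # one dict object shared by every call
    acc[key] = True
    return acc

def fixed(key, acc=None):  # fresh dict per call, the pattern used in this diff
    acc = acc or {}
    acc[key] = True
    return acc

assert broken("a") is broken("b")      # state leaks across calls
assert fixed("a") is not fixed("b")    # each call gets its own dict
```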
@@ -2059,69 +2067,18 @@ def make_in_list_sql_clause(
 KV = TypeVar("KV")
-def make_tuple_comparison_clause(
-    database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]]
-) -> Tuple[str, List[KV]]:
+def make_tuple_comparison_clause(keys: List[Tuple[str, KV]]) -> Tuple[str, List[KV]]:
     """Returns a tuple comparison SQL clause
-    Depending what the SQL engine supports, builds a SQL clause that looks like either
-    "(a, b) > (?, ?)", or "(a > ?) OR (a == ? AND b > ?)".
+    Builds a SQL clause that looks like "(a, b) > (?, ?)"
     Args:
-        database_engine
         keys: A set of (column, value) pairs to be compared.
     Returns:
         A tuple of SQL query and the args
     """
-    if database_engine.supports_tuple_comparison:
-        return (
-            "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
-            [k[1] for k in keys],
-        )
+    return (
+        "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
+        [k[1] for k in keys],
+    )
-
-    # we want to build a clause
-    #    (a > ?) OR
-    #    (a == ? AND b > ?) OR
-    #    (a == ? AND b == ? AND c > ?)
-    #    ...
-    #    (a == ? AND b == ? AND ... AND z > ?)
-    #
-    # or, equivalently:
-    #
-    #  (a > ? OR (a == ? AND
-    #    (b > ? OR (b == ? AND
-    #      ...
-    #        (y > ? OR (y == ? AND
-    #          z > ?
-    #        ))
-    #      ...
-    #    ))
-    #  ))
-    #
-    # which itself is equivalent to (and apparently easier for the query optimiser):
-    #
-    #  (a >= ? AND (a > ? OR
-    #    (b >= ? AND (b > ? OR
-    #      ...
-    #        (y >= ? AND (y > ? OR
-    #          z > ?
-    #        ))
-    #      ...
-    #    ))
-    #  ))
-    #
-    #
-    clause = ""
-    args = []  # type: List[KV]
-    for k, v in keys[:-1]:
-        clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k)
-        args.extend([v, v])
-    (k, v) = keys[-1]
-    clause += "%s > ?" % (k,)
-    args.append(v)
-    clause += "))" * (len(keys) - 1)
-    return clause, args
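With the minimum SQLite and PostgreSQL versions raised later in this diff, row-value (tuple) comparison is available on every supported engine, so the emulated `(a >= ? AND (a > ? OR ...))` fallback can go. A self-contained sketch of what the simplified helper produces:

```python
from typing import Any, List, Tuple

def make_tuple_comparison_clause(keys: List[Tuple[str, Any]]) -> Tuple[str, List[Any]]:
    # Mirrors the simplified helper in the hunk above.
    return (
        "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
        [k[1] for k in keys],
    )

clause, args = make_tuple_comparison_clause(
    [("topological_ordering", 10), ("stream_ordering", 42)]
)
assert clause == "(topological_ordering,stream_ordering) > (?,?)"
assert args == [10, 42]
```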

View file

@@ -21,6 +21,7 @@ from typing import List, Optional, Tuple
 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     IdGenerator,
@@ -292,6 +293,8 @@ class DataStore(
         name: Optional[str] = None,
         guests: bool = True,
         deactivated: bool = False,
+        order_by: UserSortOrder = UserSortOrder.USER_ID.value,
+        direction: str = "f",
     ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
@@ -304,6 +307,8 @@ class DataStore(
             name: search for local part of user_id or display name
             guests: whether to in include guest users
             deactivated: whether to include deactivated users
+            order_by: the sort order of the returned list
+            direction: sort ascending or descending
         Returns:
             A tuple of a list of mappings from user to information and a count of total users.
         """
@@ -312,6 +317,14 @@ class DataStore(
             filters = []
             args = [self.hs.config.server_name]
+            # Set ordering
+            order_by_column = UserSortOrder(order_by).value
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
             # `name` is in database already in lower case
             if name:
                 filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
@@ -339,10 +352,15 @@ class DataStore(
             txn.execute(sql, args)
             count = txn.fetchone()[0]
-            sql = (
-                "SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url "
-                + sql_base
-                + " ORDER BY u.name LIMIT ? OFFSET ?"
-            )
+            sql = """
+                SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url
+                {sql_base}
+                ORDER BY {order_by_column} {order}, u.name ASC
+                LIMIT ? OFFSET ?
+                """.format(
+                sql_base=sql_base,
+                order_by_column=order_by_column,
+                order=order,
+            )
             args += [limit, start]
             txn.execute(sql, args)
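Note that `order_by_column` and `order` are interpolated into the SQL string rather than bound as parameters (placeholders cannot name columns). This stays safe because `UserSortOrder(order_by)` validates the input against a closed enum first, raising `ValueError` for anything else. A minimal sketch of the pattern, with a hypothetical enum:

```python
from enum import Enum

class SortColumn(Enum):
    USER_ID = "user_id"
    DISPLAYNAME = "displayname"

def order_clause(requested: str, descending: bool = False) -> str:
    # ValueError here rejects anything not in the whitelist, so the value
    # can be formatted into the SQL text without injection risk.
    column = SortColumn(requested).value
    return "ORDER BY %s %s" % (column, "DESC" if descending else "ASC")

print(order_clause("displayname"))       # ORDER BY displayname ASC
# order_clause("1; DROP TABLE users")    # raises ValueError, never reaches SQL
```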

View file

@@ -298,7 +298,6 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             # times, which is fine.
             where_clause, where_args = make_tuple_comparison_clause(
-                self.database_engine,
                 [("user_id", last_user_id), ("device_id", last_device_id)],
             )

View file

@@ -985,7 +985,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
         def _txn(txn):
             clause, args = make_tuple_comparison_clause(
-                self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS]
+                [(x, last_row[x]) for x in KEY_COLS]
             )
             sql = """
                 SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts

View file

@@ -320,8 +320,8 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool,
-        state_delta_for_room: Dict[str, DeltaState] = {},
-        new_forward_extremeties: Dict[str, List[str]] = {},
+        state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
+        new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
     ):
         """Insert some number of room events into the necessary database tables.
@@ -342,6 +342,9 @@ class PersistEventsStore:
                 extremities.
         """
+        state_delta_for_room = state_delta_for_room or {}
+        new_forward_extremeties = new_forward_extremeties or {}
+
         all_events_and_contexts = events_and_contexts
         min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering

View file

@@ -838,7 +838,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
             # comparison, but that is not supported on older SQLite versions
             tuple_clause, tuple_args = make_tuple_comparison_clause(
-                self.database_engine,
                 [
                     ("events.room_id", last_room_id),
                     ("topological_ordering", last_depth),

View file

@@ -16,7 +16,7 @@
 import logging
 import threading
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple, overload
+from typing import Container, Dict, Iterable, List, Optional, Tuple, overload
 from constantly import NamedConstant, Names
 from typing_extensions import Literal
@@ -544,7 +544,7 @@ class EventsWorkerStore(SQLBaseStore):
     async def get_stripped_room_state_from_event_context(
         self,
         context: EventContext,
-        state_types_to_include: List[EventTypes],
+        state_types_to_include: Container[str],
         membership_user_id: Optional[str] = None,
     ) -> List[JsonDict]:
         """

View file

@@ -1027,8 +1027,8 @@ class GroupServerStore(GroupServerWorkerStore):
         user_id: str,
         is_admin: bool = False,
         is_public: bool = True,
-        local_attestation: dict = None,
-        remote_attestation: dict = None,
+        local_attestation: Optional[dict] = None,
+        remote_attestation: Optional[dict] = None,
     ) -> None:
         """Add a user to the group server.
@@ -1171,7 +1171,7 @@ class GroupServerStore(GroupServerWorkerStore):
         user_id: str,
         membership: str,
         is_admin: bool = False,
-        content: JsonDict = {},
+        content: Optional[JsonDict] = None,
         local_attestation: Optional[dict] = None,
         remote_attestation: Optional[dict] = None,
         is_publicised: bool = False,
@@ -1192,6 +1192,8 @@ class GroupServerStore(GroupServerWorkerStore):
             is_publicised: Whether this should be publicised.
         """
+        content = content or {}
+
         def _register_user_group_membership_txn(txn, next_id):
             # TODO: Upsert?
             self.db_pool.simple_delete_txn(

View file

@@ -22,6 +22,9 @@ from synapse.storage.database import DatabasePool
 BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
     "media_repository_drop_index_wo_method"
 )
+BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
+    "media_repository_drop_index_wo_method_2"
+)
 class MediaSortOrder(Enum):
@@ -85,23 +88,35 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             unique=True,
         )
+        # the original impl of _drop_media_index_without_method was broken (see
+        # https://github.com/matrix-org/synapse/issues/8649), so we replace the original
+        # impl with a no-op and run the fixed migration as
+        # media_repository_drop_index_wo_method_2.
+        self.db_pool.updates.register_noop_background_update(
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+        )
+
         self.db_pool.updates.register_background_update_handler(
-            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD,
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
             self._drop_media_index_without_method,
         )
     async def _drop_media_index_without_method(self, progress, batch_size):
+        """background update handler which removes the old constraints.
+
+        Note that this is only run on postgres.
+        """
+
         def f(txn):
             txn.execute(
                 "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
             )
             txn.execute(
-                "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+                "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"
             )
         await self.db_pool.runInteraction("drop_media_indices_without_method", f)
         await self.db_pool.updates._end_background_update(
-            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2
         )
         return 1

View file

@@ -521,13 +521,11 @@ class RoomWorkerStore(SQLBaseStore):
         )
     @cached(max_entries=10000)
-    async def get_ratelimit_for_user(self, user_id):
-        """Check if there are any overrides for ratelimiting for the given
-        user
+    async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]:
+        """Check if there are any overrides for ratelimiting for the given user
         Args:
-            user_id (str)
+            user_id: user ID of the user
         Returns:
             RatelimitOverride if there is an override, else None. If the contents
             of RatelimitOverride are None or 0 then ratelimitng has been
@@ -549,6 +547,62 @@ class RoomWorkerStore(SQLBaseStore):
         else:
             return None
+    async def set_ratelimit_for_user(
+        self, user_id: str, messages_per_second: int, burst_count: int
+    ) -> None:
+        """Sets whether a user is set an overridden ratelimit.
+        Args:
+            user_id: user ID of the user
+            messages_per_second: The number of actions that can be performed in a second.
+            burst_count: How many actions that can be performed before being limited.
+        """
+
+        def set_ratelimit_txn(txn):
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="ratelimit_override",
+                keyvalues={"user_id": user_id},
+                values={
+                    "messages_per_second": messages_per_second,
+                    "burst_count": burst_count,
+                },
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_ratelimit_for_user, (user_id,)
+            )
+
+        await self.db_pool.runInteraction("set_ratelimit", set_ratelimit_txn)
+
+    async def delete_ratelimit_for_user(self, user_id: str) -> None:
+        """Delete an overridden ratelimit for a user.
+        Args:
+            user_id: user ID of the user
+        """
+
+        def delete_ratelimit_txn(txn):
+            row = self.db_pool.simple_select_one_txn(
+                txn,
+                table="ratelimit_override",
+                keyvalues={"user_id": user_id},
+                retcols=["user_id"],
+                allow_none=True,
+            )
+
+            if not row:
+                return
+
+            # They are there, delete them.
+            self.db_pool.simple_delete_one_txn(
+                txn, "ratelimit_override", keyvalues={"user_id": user_id}
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_ratelimit_for_user, (user_id,)
+            )
+
+        await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
+
     @cached()
     async def get_retention_policy_for_room(self, room_id):
         """Get the retention policy for a given room.

View file

@@ -0,0 +1,22 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- drop old constraints on remote_media_cache_thumbnails
+--
+-- This was originally part of 57.07, but it was done wrong, per
+-- https://github.com/matrix-org/synapse/issues/8649, so we do it again.
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (5911, 'media_repository_drop_index_wo_method_2', '{}', 'remote_media_repository_thumbnails_method_idx');

View file

@@ -190,7 +190,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     # FIXME: how should this be cached?
     async def get_filtered_current_state_ids(
-        self, room_id: str, state_filter: StateFilter = StateFilter.all()
+        self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """Get the current state event of a given type for a room based on the
         current_state_events table. This may not be as up-to-date as the result
@@ -205,7 +205,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             Map from type/state_key to event ID.
         """
-        where_clause, where_args = state_filter.make_sql_filter_clause()
+        where_clause, where_args = (
+            state_filter or StateFilter.all()
+        ).make_sql_filter_clause()
         if not where_clause:
             # We delegate to the cached version
View file

@@ -66,18 +66,37 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
 class UserSortOrder(Enum):
     """
     Enum to define the sorting method used when returning users
-    with get_users_media_usage_paginate
+    with get_users_paginate in __init__.py
+    and get_users_media_usage_paginate in stats.py
+
+    When moves this to __init__.py gets `builtins.ImportError` with
+    `most likely due to a circular import`
-    MEDIA_LENGTH = ordered by size of uploaded media. Smallest to largest.
-    MEDIA_COUNT = ordered by number of uploaded media. Smallest to largest.
+    MEDIA_LENGTH = ordered by size of uploaded media.
+    MEDIA_COUNT = ordered by number of uploaded media.
     USER_ID = ordered alphabetically by `user_id`.
+    NAME = ordered alphabetically by `user_id`. This is for compatibility reasons,
+    as the user_id is returned in the name field in the response in list users admin API.
     DISPLAYNAME = ordered alphabetically by `displayname`
+    GUEST = ordered by `is_guest`
+    ADMIN = ordered by `admin`
+    DEACTIVATED = ordered by `deactivated`
+    USER_TYPE = ordered alphabetically by `user_type`
+    AVATAR_URL = ordered alphabetically by `avatar_url`
+    SHADOW_BANNED = ordered by `shadow_banned`
     """
     MEDIA_LENGTH = "media_length"
     MEDIA_COUNT = "media_count"
     USER_ID = "user_id"
+    NAME = "name"
     DISPLAYNAME = "displayname"
+    GUEST = "is_guest"
+    ADMIN = "admin"
+    DEACTIVATED = "deactivated"
+    USER_TYPE = "user_type"
+    AVATAR_URL = "avatar_url"
+    SHADOW_BANNED = "shadow_banned"
 class StatsStore(StateDeltasStore):

View file

@@ -14,6 +14,7 @@
 # limitations under the License.
 import logging
+from typing import Optional
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
         return count
     def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter=StateFilter.all()
+        self, txn, groups, state_filter: Optional[StateFilter] = None
     ):
+        state_filter = state_filter or StateFilter.all()
+
         results = {group: {} for group in groups}
         where_clause, where_args = state_filter.make_sql_filter_clause()

View file

@@ -15,7 +15,7 @@
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
@@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return state_filter.filter_state(state_dict_ids), not missing_types
     async def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Returns:
             Dict of state group to state map.
         """
+        state_filter = state_filter or StateFilter.all()
         member_filter, non_member_filter = state_filter.get_member_split()

View file

@@ -42,14 +42,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         """
         ...
-    @property
-    @abc.abstractmethod
-    def supports_tuple_comparison(self) -> bool:
-        """
-        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
-        """
-        ...
-
     @property
     @abc.abstractmethod
     def supports_using_any_list(self) -> bool:

View file

@@ -47,8 +47,8 @@ class PostgresEngine(BaseDatabaseEngine):
             self._version = db_conn.server_version
             # Are we on a supported PostgreSQL version?
-            if not allow_outdated_version and self._version < 90500:
-                raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
+            if not allow_outdated_version and self._version < 90600:
+                raise RuntimeError("Synapse requires PostgreSQL 9.6 or above.")
         with db_conn.cursor() as txn:
             txn.execute("SHOW SERVER_ENCODING")
@@ -129,13 +129,6 @@ class PostgresEngine(BaseDatabaseEngine):
         """
         return True
-    @property
-    def supports_tuple_comparison(self):
-        """
-        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
-        """
-        return True
-
     @property
     def supports_using_any_list(self):
         """Do we support using `a = ANY(?)` and passing a list"""

View file

@@ -56,14 +56,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         """
         return self.module.sqlite_version_info >= (3, 24, 0)
-    @property
-    def supports_tuple_comparison(self):
-        """
-        Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires
-        SQLite 3.15+.
-        """
-        return self.module.sqlite_version_info >= (3, 15, 0)
-
     @property
     def supports_using_any_list(self):
         """Do we support using `a = ANY(?)` and passing a list"""
@@ -72,8 +64,11 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
     def check_database(self, db_conn, allow_outdated_version: bool = False):
         if not allow_outdated_version:
             version = self.module.sqlite_version_info
-            if version < (3, 11, 0):
-                raise RuntimeError("Synapse requires sqlite 3.11 or above.")
+            # Synapse is untested against older SQLite versions, and we don't want
+            # to let users upgrade to a version of Synapse with broken support for their
+            # sqlite version, because it risks leaving them with a half-upgraded db.
+            if version < (3, 22, 0):
+                raise RuntimeError("Synapse requires sqlite 3.22 or above.")
     def check_new_database(self, txn):
         """Gets called when setting up a brand new database. This allows us to

View file

@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import imp
+import importlib.util
 import logging
 import os
 import re
@@ -454,8 +454,13 @@ def _upgrade_existing_database(
                 )
                 module_name = "synapse.storage.v%d_%s" % (v, root_name)
-                with open(absolute_path) as python_file:
-                    module = imp.load_source(module_name, absolute_path, python_file)  # type: ignore
+
+                spec = importlib.util.spec_from_file_location(
+                    module_name, absolute_path
+                )
+                module = importlib.util.module_from_spec(spec)
+                spec.loader.exec_module(module)  # type: ignore
+
                 logger.info("Running script %s", relative_path)
                 module.run_create(cur, database_engine)  # type: ignore
                 if not is_empty:
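`imp.load_source` has been deprecated since Python 3.4 (and the `imp` module is removed in 3.12), so the migration loader switches to the documented `importlib.util` recipe. The standalone equivalent:

```python
import importlib.util

def load_module_from_path(module_name: str, path: str):
    # spec_from_file_location + module_from_spec + exec_module is the
    # documented replacement for imp.load_source.
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
```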

View file

@@ -449,7 +449,7 @@ class StateGroupStorage:
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
     async def get_state_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
@@ -465,7 +465,7 @@ class StateGroupStorage:
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
         state_event_map = await self.stores.main.get_events(
@@ -485,7 +485,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
     async def get_state_ids_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -502,7 +502,7 @@ class StateGroupStorage:
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
         event_to_state = {
@@ -513,7 +513,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
     async def get_state_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
         """
         Get the state dict corresponding to a particular event
@@ -525,11 +525,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_for_events([event_id], state_filter)
+        state_map = await self.get_state_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -541,11 +543,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_ids_for_events([event_id], state_filter)
+        state_map = await self.get_state_ids_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
     def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
    ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -558,7 +562,9 @@ class StateGroupStorage:
         Returns:
             Dict of state group to state map.
         """
-        return self.stores.state._get_state_for_groups(groups, state_filter)
+        return self.stores.state._get_state_for_groups(
+            groups, state_filter or StateFilter.all()
+        )
     async def store_state_group(
         self,

View file

@@ -17,7 +17,7 @@ import logging
 import threading
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
 import attr
@@ -91,7 +91,14 @@ class StreamIdGenerator:
         #     ... persist event ...
         """
-    def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+    def __init__(
+        self,
+        db_conn,
+        table,
+        column,
+        extra_tables: Iterable[Tuple[str, str]] = (),
+        step=1,
+    ):
         assert step != 0
         self._lock = threading.Lock()
         self._step = step

Some files were not shown because too many files have changed in this diff.