Merge branch 'develop' into e2e_backups

commit 83caead95a
Hubert Chathi, 2018-08-24 11:44:26 -04:00, committed by GitHub
232 changed files with 7197 additions and 4107 deletions

.circleci/config.yml (new file, 48 lines)
@@ -0,0 +1,48 @@
version: 2
jobs:
  sytestpy2:
    machine: true
    steps:
      - checkout
      - run: docker pull matrixdotorg/sytest-synapsepy2
      - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs matrixdotorg/sytest-synapsepy2
      - store_artifacts:
          path: ~/project/logs
          destination: logs
  sytestpy2postgres:
    machine: true
    steps:
      - checkout
      - run: docker pull matrixdotorg/sytest-synapsepy2
      - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy2
      - store_artifacts:
          path: ~/project/logs
          destination: logs
  sytestpy3:
    machine: true
    steps:
      - checkout
      - run: docker pull matrixdotorg/sytest-synapsepy3
      - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs hawkowl/sytestpy3
      - store_artifacts:
          path: ~/project/logs
          destination: logs
  sytestpy3postgres:
    machine: true
    steps:
      - checkout
      - run: docker pull matrixdotorg/sytest-synapsepy3
      - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy3
      - store_artifacts:
          path: ~/project/logs
          destination: logs
workflows:
  version: 2
  build:
    jobs:
      - sytestpy2
      - sytestpy2postgres
      # Currently broken while the Python 3 port is incomplete
      # - sytestpy3
      # - sytestpy3postgres

.dockerignore

@@ -3,3 +3,6 @@ Dockerfile
 .gitignore
 demo/etc
 tox.ini
+synctl
+.git/*
+.tox/*

.travis.yml

@@ -8,6 +8,9 @@ before_script:
   - git remote set-branches --add origin develop
   - git fetch origin develop
 
+services:
+  - postgresql
+
 matrix:
   fast_finish: true
   include:
@@ -20,6 +23,9 @@ matrix:
     - python: 2.7
       env: TOX_ENV=py27
 
+    - python: 2.7
+      env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
+
     - python: 3.6
       env: TOX_ENV=py36
@@ -29,6 +35,10 @@ matrix:
     - python: 3.6
       env: TOX_ENV=check-newsfragment
 
+  allow_failures:
+    - python: 2.7
+      env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
+
 install:
   - pip install tox

CHANGES.md

@@ -1,3 +1,85 @@
+Synapse 0.33.3 (2018-08-22)
+===========================
+
+Bugfixes
+--------
+
+- Fix bug introduced in v0.33.3rc1 which made the ToS give a 500 error ([\#3732](https://github.com/matrix-org/synapse/issues/3732))
+
+
+Synapse 0.33.3rc2 (2018-08-21)
+==============================
+
+Bugfixes
+--------
+
+- Fix bug in v0.33.3rc1 which caused infinite loops and OOMs ([\#3723](https://github.com/matrix-org/synapse/issues/3723))
+
+
+Synapse 0.33.3rc1 (2018-08-21)
+==============================
+
+Features
+--------
+
+- Add support for the SNI extension to federation TLS connections. Thanks to @vojeroen! ([\#3439](https://github.com/matrix-org/synapse/issues/3439))
+- Add /_media/r0/config ([\#3184](https://github.com/matrix-org/synapse/issues/3184))
+- speed up /members API and add `at` and `membership` params as per MSC1227 ([\#3568](https://github.com/matrix-org/synapse/issues/3568))
+- implement `summary` block in /sync response as per MSC688 ([\#3574](https://github.com/matrix-org/synapse/issues/3574))
+- Add lazy-loading support to /messages as per MSC1227 ([\#3589](https://github.com/matrix-org/synapse/issues/3589))
+- Add ability to limit number of monthly active users on the server ([\#3633](https://github.com/matrix-org/synapse/issues/3633))
+- Support more federation endpoints on workers ([\#3653](https://github.com/matrix-org/synapse/issues/3653))
+- Basic support for room versioning ([\#3654](https://github.com/matrix-org/synapse/issues/3654))
+- Ability to disable client/server Synapse via conf toggle ([\#3655](https://github.com/matrix-org/synapse/issues/3655))
+- Ability to whitelist specific threepids against monthly active user limiting ([\#3662](https://github.com/matrix-org/synapse/issues/3662))
+- Add some metrics for the appservice and federation event sending loops ([\#3664](https://github.com/matrix-org/synapse/issues/3664))
+- Where server is disabled, block ability for locked out users to read new messages ([\#3670](https://github.com/matrix-org/synapse/issues/3670))
+- set admin uri via config, to be used in error messages where the user should contact the administrator ([\#3687](https://github.com/matrix-org/synapse/issues/3687))
+- Synapse's presence functionality can now be disabled with the "use_presence" configuration option. ([\#3694](https://github.com/matrix-org/synapse/issues/3694))
+- For resource limit blocked users, prevent writing into rooms ([\#3708](https://github.com/matrix-org/synapse/issues/3708))
+
+
+Bugfixes
+--------
+
+- Fix occasional glitches in the synapse_event_persisted_position metric ([\#3658](https://github.com/matrix-org/synapse/issues/3658))
+- Fix bug on deleting 3pid when using identity servers that don't support unbind API ([\#3661](https://github.com/matrix-org/synapse/issues/3661))
+- Make the tests pass on Twisted < 18.7.0 ([\#3676](https://github.com/matrix-org/synapse/issues/3676))
+- Dont ship recaptcha_ajax.js, use it directly from Google ([\#3677](https://github.com/matrix-org/synapse/issues/3677))
+- Fixes test_reap_monthly_active_users so it passes under postgres ([\#3681](https://github.com/matrix-org/synapse/issues/3681))
+- Fix mau blocking calulation bug on login ([\#3689](https://github.com/matrix-org/synapse/issues/3689))
+- Fix missing yield in synapse.storage.monthly_active_users.initialise_reserved_users ([\#3692](https://github.com/matrix-org/synapse/issues/3692))
+- Improve HTTP request logging to include all requests ([\#3700](https://github.com/matrix-org/synapse/issues/3700))
+- Avoid timing out requests while we are streaming back the response ([\#3701](https://github.com/matrix-org/synapse/issues/3701))
+- Support more federation endpoints on workers ([\#3705](https://github.com/matrix-org/synapse/issues/3705), [\#3713](https://github.com/matrix-org/synapse/issues/3713))
+- Fix "Starting db txn 'get_all_updated_receipts' from sentinel context" warning ([\#3710](https://github.com/matrix-org/synapse/issues/3710))
+- Fix bug where `state_cache` cache factor ignored environment variables ([\#3719](https://github.com/matrix-org/synapse/issues/3719))
+
+
+Deprecations and Removals
+-------------------------
+
+- The Shared-Secret registration method of the legacy v1/register REST endpoint has been removed. For a replacement, please see [the admin/register API documentation](https://github.com/matrix-org/synapse/blob/master/docs/admin_api/register_api.rst). ([\#3703](https://github.com/matrix-org/synapse/issues/3703))
+
+
+Internal Changes
+----------------
+
+- The test suite now can run under PostgreSQL. ([\#3423](https://github.com/matrix-org/synapse/issues/3423))
+- Refactor HTTP replication endpoints to reduce code duplication ([\#3632](https://github.com/matrix-org/synapse/issues/3632))
+- Tests now correctly execute on Python 3. ([\#3647](https://github.com/matrix-org/synapse/issues/3647))
+- Sytests can now be run inside a Docker container. ([\#3660](https://github.com/matrix-org/synapse/issues/3660))
+- Port over enough to Python 3 to allow the sytests to start. ([\#3668](https://github.com/matrix-org/synapse/issues/3668))
+- Update docker base image from alpine 3.7 to 3.8. ([\#3669](https://github.com/matrix-org/synapse/issues/3669))
+- Rename synapse.util.async to synapse.util.async_helpers to mitigate async becoming a keyword on Python 3.7. ([\#3678](https://github.com/matrix-org/synapse/issues/3678))
+- Synapse's tests are now formatted with the black autoformatter. ([\#3679](https://github.com/matrix-org/synapse/issues/3679))
+- Implemented a new testing base class to reduce test boilerplate. ([\#3684](https://github.com/matrix-org/synapse/issues/3684))
+- Rename MAU prometheus metrics ([\#3690](https://github.com/matrix-org/synapse/issues/3690))
+- add new error type ResourceLimit ([\#3707](https://github.com/matrix-org/synapse/issues/3707))
+- Logcontexts for replication command handlers ([\#3709](https://github.com/matrix-org/synapse/issues/3709))
+- Update admin register API documentation to reference a real user ID. ([\#3712](https://github.com/matrix-org/synapse/issues/3712))
+
+
 Synapse 0.33.2 (2018-08-09)
 ===========================
@@ -24,7 +106,7 @@ Features
 
 Bugfixes
 --------
 
-- Make /directory/list API return 404 for room not found instead of 400 ([\#2952](https://github.com/matrix-org/synapse/issues/2952))
+- Make /directory/list API return 404 for room not found instead of 400. Thanks to @fuzzmz! ([\#3620](https://github.com/matrix-org/synapse/issues/3620))
 - Default inviter_display_name to mxid for email invites ([\#3391](https://github.com/matrix-org/synapse/issues/3391))
 - Don't generate TURN credentials if no TURN config options are set ([\#3514](https://github.com/matrix-org/synapse/issues/3514))
 - Correctly announce deleted devices over federation ([\#3520](https://github.com/matrix-org/synapse/issues/3520))

CONTRIBUTING.rst

@@ -56,17 +56,17 @@ entry. These are managed by Towncrier
 (https://github.com/hawkowl/towncrier).
 
 To create a changelog entry, make a new file in the ``changelog.d``
-file named in the format of ``issuenumberOrPR.type``. The type can be
+file named in the format of ``PRnumber.type``. The type can be
 one of ``feature``, ``bugfix``, ``removal`` (also used for
 deprecations), or ``misc`` (for internal-only changes). The content of
 the file is your changelog entry, which can contain RestructuredText
 formatting. A note of contributors is welcomed in changelogs for
 non-misc changes (the content of misc changes is not displayed).
 
-For example, a fix for a bug reported in #1234 would have its
-changelog entry in ``changelog.d/1234.bugfix``, and contain content
-like "The security levels of Florbs are now validated when
-recieved over federation. Contributed by Jane Matrix".
+For example, a fix in PR #1234 would have its changelog entry in
+``changelog.d/1234.bugfix``, and contain content like "The security levels of
+Florbs are now validated when recieved over federation. Contributed by Jane
+Matrix".
 
 Attribution
 ~~~~~~~~~~~

MANIFEST.in

@@ -36,3 +36,4 @@ recursive-include changelog.d *
 prune .github
 prune demo/etc
 prune docker
+prune .circleci

changelog.d/3659.feature (new file)
@@ -0,0 +1 @@
+Support profile API endpoints on workers

changelog.d/3673.misc (new file)
@@ -0,0 +1 @@
+Refactor state module to support multiple room versions

changelog.d/3680.feature (new file)
@@ -0,0 +1 @@
+Server notices for resource limit blocking

changelog.d/3722.bugfix (new file)
@@ -0,0 +1 @@
+Fix error collecting prometheus metrics when run on dedicated thread due to threading concurrency issues

changelog.d/3724.feature (new file)
@@ -0,0 +1 @@
+Allow guests to use /rooms/:roomId/event/:eventId

changelog.d/3726.misc (new file)
@@ -0,0 +1 @@
+Split the state_group_cache into member and non-member state events (and so speed up LL /sync)

changelog.d/3727.misc (new file)
@@ -0,0 +1 @@
+Log failure to authenticate remote servers as warnings (without stack traces)

changelog.d/3734.misc (new file)
@@ -0,0 +1 @@
+Reference the need for an HTTP replication port when using the federation_reader worker

changelog.d/3735.misc (new file)
@@ -0,0 +1 @@
+Fix minor spelling error in federation client documentation.

changelog.d/3746.misc (new file)
@@ -0,0 +1 @@
+Fix MAU cache invalidation due to missing yield

changelog.d/3747.bugfix (new file)
@@ -0,0 +1 @@
+Fix bug where we resent "limit exceeded" server notices repeatedly

changelog.d/3749.feature (new file)
@@ -0,0 +1 @@
+Add mau_trial_days config param, so that users only get counted as MAU after N days.

changelog.d/3751.feature (new file)
@@ -0,0 +1 @@
+Require twisted 17.1 or later (fixes [#3741](https://github.com/matrix-org/synapse/issues/3741)).

changelog.d/3753.bugfix (new file)
@@ -0,0 +1 @@
+Fix bug where we broke sync when using limit_usage_by_mau but hadn't configured server notices

changelog.d/3754.bugfix (new file)
@@ -0,0 +1 @@
+Fix 'federation_domain_whitelist' such that an empty list correctly blocks all outbound federation traffic

changelog.d/3755.bugfix (new file)
@@ -0,0 +1 @@
+Fix tagging of server notice rooms

contrib/grafana/synapse.json

@@ -54,7 +54,7 @@
   "gnetId": null,
   "graphTooltip": 0,
   "id": null,
-  "iteration": 1533026624326,
+  "iteration": 1533598785368,
   "links": [
     {
       "asDropdown": true,
@@ -4629,7 +4629,7 @@
         "h": 9,
         "w": 12,
         "x": 0,
-        "y": 11
+        "y": 29
       },
       "id": 67,
       "legend": {
@@ -4655,11 +4655,11 @@
       "steppedLine": false,
       "targets": [
         {
-          "expr": " synapse_event_persisted_position{instance=\"$instance\"} - ignoring(index, job, name) group_right(instance) synapse_event_processing_positions{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+          "expr": " synapse_event_persisted_position{instance=\"$instance\",job=\"synapse\"} - ignoring(index, job, name) group_right() synapse_event_processing_positions{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
           "format": "time_series",
           "interval": "",
           "intervalFactor": 1,
-          "legendFormat": "{{job}}-{{index}}",
+          "legendFormat": "{{job}}-{{index}} ",
           "refId": "A"
         }
       ],
@@ -4697,7 +4697,11 @@
           "min": null,
           "show": true
         }
-      ]
+      ],
+      "yaxis": {
+        "align": false,
+        "alignLevel": null
+      }
     },
     {
       "aliasColors": {},
@@ -4710,7 +4714,7 @@
         "h": 9,
         "w": 12,
         "x": 12,
-        "y": 11
+        "y": 29
       },
       "id": 71,
       "legend": {
@@ -4778,7 +4782,11 @@
           "min": null,
           "show": true
         }
-      ]
+      ],
+      "yaxis": {
+        "align": false,
+        "alignLevel": null
+      }
     }
   ],
   "title": "Event processing loop positions",
@@ -4957,5 +4965,5 @@
   "timezone": "",
   "title": "Synapse",
   "uid": "000000012",
-  "version": 125
+  "version": 127
 }

docker/Dockerfile

@@ -1,4 +1,4 @@
-FROM docker.io/python:2-alpine3.7
+FROM docker.io/python:2-alpine3.8
 
 RUN apk add --no-cache --virtual .nacl_deps \
         build-base \

docs/admin_api/register_api.rst

@@ -33,7 +33,7 @@ As an example::
 
     < {
         "access_token": "token_here",
-        "user_id": "@pepper_roni@test",
+        "user_id": "@pepper_roni:localhost",
         "home_server": "test",
         "device_id": "device_id_here"
     }

docs/workers.rst

@@ -74,7 +74,7 @@ replication endpoints that it's talking to on the main synapse process.
 ``worker_replication_port`` should point to the TCP replication listener port and
 ``worker_replication_http_port`` should point to the HTTP replication port.
 
-Currently, only the ``event_creator`` worker requires specifying
+Currently, the ``event_creator`` and ``federation_reader`` workers require specifying
 ``worker_replication_http_port``.
 
 For instance::
@@ -173,10 +173,23 @@ endpoints matching the following regular expressions::
     ^/_matrix/federation/v1/backfill/
     ^/_matrix/federation/v1/get_missing_events/
     ^/_matrix/federation/v1/publicRooms
+    ^/_matrix/federation/v1/query/
+    ^/_matrix/federation/v1/make_join/
+    ^/_matrix/federation/v1/make_leave/
+    ^/_matrix/federation/v1/send_join/
+    ^/_matrix/federation/v1/send_leave/
+    ^/_matrix/federation/v1/invite/
+    ^/_matrix/federation/v1/query_auth/
+    ^/_matrix/federation/v1/event_auth/
+    ^/_matrix/federation/v1/exchange_third_party_invite/
+    ^/_matrix/federation/v1/send/
 
 The above endpoints should all be routed to the federation_reader worker by the
 reverse-proxy configuration.
 
+The `^/_matrix/federation/v1/send/` endpoint must only be handled by a single
+instance.
+
 ``synapse.app.federation_sender``
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -228,6 +241,14 @@ regular expressions::
     ^/_matrix/client/(api/v1|r0|unstable)/keys/upload
 
+If ``use_presence`` is False in the homeserver config, it can also handle REST
+endpoints matching the following regular expressions::
+
+    ^/_matrix/client/(api/v1|r0|unstable)/presence/[^/]+/status
+
+This "stub" presence handler will pass through ``GET`` request but make the
+``PUT`` effectively a no-op.
+
 It will proxy any requests it cannot handle to the main synapse instance. It
 must therefore be configured with the location of the main instance, via
 the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
@@ -244,6 +265,7 @@ Handles some event creation. It can handle REST endpoints matching::
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
     ^/_matrix/client/(api/v1|r0|unstable)/join/
+    ^/_matrix/client/(api/v1|r0|unstable)/profile/
 
 It will create events locally and then send them on to the main synapse
 instance to be persisted and handled.
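
For illustration only (nothing below is part of the commit): the routing rules
above can be exercised with a few lines of Python. The pattern list is an
excerpt of the endpoint list above, and the dispatch helper is hypothetical.

    import re

    # Excerpt of the federation_reader patterns documented above.
    FEDERATION_READER_PATTERNS = [
        r"^/_matrix/federation/v1/query/",
        r"^/_matrix/federation/v1/make_join/",
        r"^/_matrix/federation/v1/send/",
    ]

    def route_to_federation_reader(path):
        # True if a reverse proxy should hand this path to the worker.
        return any(re.match(p, path) for p in FEDERATION_READER_PATTERNS)

    print(route_to_federation_reader("/_matrix/federation/v1/send/1234"))  # True
    print(route_to_federation_reader("/_matrix/client/r0/sync"))           # False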

synapse/__init__.py

@@ -17,4 +17,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.33.2"
+__version__ = "0.33.3"

synapse/api/auth.py

@@ -25,7 +25,7 @@ from twisted.internet import defer
 import synapse.types
 from synapse import event_auth
 from synapse.api.constants import EventTypes, JoinRules, Membership
-from synapse.api.errors import AuthError, Codes
+from synapse.api.errors import AuthError, Codes, ResourceLimitError
 from synapse.types import UserID
 from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
 from synapse.util.caches.lrucache import LruCache
@@ -211,9 +211,9 @@ class Auth(object):
             user_agent = request.requestHeaders.getRawHeaders(
                 b"User-Agent",
                 default=[b""]
-            )[0]
+            )[0].decode('ascii', 'surrogateescape')
             if user and access_token and ip_addr:
-                self.store.insert_client_ip(
+                yield self.store.insert_client_ip(
                     user_id=user.to_string(),
                     access_token=access_token,
                     ip=ip_addr,
@@ -682,7 +682,7 @@ class Auth(object):
         Returns:
             bool: False if no access_token was given, True otherwise.
         """
-        query_params = request.args.get("access_token")
+        query_params = request.args.get(b"access_token")
        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
         return bool(query_params) or bool(auth_headers)
@@ -698,7 +698,7 @@ class Auth(object):
             401 since some of the old clients depended on auth errors returning
             403.
         Returns:
-            str: The access_token
+            unicode: The access_token
         Raises:
             AuthError: If there isn't an access_token in the request.
         """
@@ -720,9 +720,9 @@ class Auth(object):
                     "Too many Authorization headers.",
                     errcode=Codes.MISSING_TOKEN,
                 )
-            parts = auth_headers[0].split(" ")
-            if parts[0] == "Bearer" and len(parts) == 2:
-                return parts[1]
+            parts = auth_headers[0].split(b" ")
+            if parts[0] == b"Bearer" and len(parts) == 2:
+                return parts[1].decode('ascii')
             else:
                 raise AuthError(
                     token_not_found_http_status,
@@ -738,7 +738,7 @@ class Auth(object):
                     errcode=Codes.MISSING_TOKEN
                 )
-            return query_params[0]
+            return query_params[0].decode('ascii')
 
     @defer.inlineCallbacks
     def check_in_room_or_world_readable(self, room_id, user_id):
@@ -773,3 +773,46 @@ class Auth(object):
             raise AuthError(
                 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
             )
+
+    @defer.inlineCallbacks
+    def check_auth_blocking(self, user_id=None):
+        """Checks if the user should be rejected for some external reason,
+        such as monthly active user limiting or global disable flag
+
+        Args:
+            user_id(str|None): If present, checks for presence against existing
+            MAU cohort
+        """
+        # Never fail an auth check for the server notices users
+        # This can be a problem where event creation is prohibited due to blocking
+        if user_id == self.hs.config.server_notices_mxid:
+            return
+
+        if self.hs.config.hs_disabled:
+            raise ResourceLimitError(
+                403, self.hs.config.hs_disabled_message,
+                errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
+                admin_uri=self.hs.config.admin_uri,
+                limit_type=self.hs.config.hs_disabled_limit_type
+            )
+
+        if self.hs.config.limit_usage_by_mau is True:
+            # If the user is already part of the MAU cohort or a trial user
+            if user_id:
+                timestamp = yield self.store.user_last_seen_monthly_active(user_id)
+                if timestamp:
+                    return
+
+                is_trial = yield self.store.is_trial_user(user_id)
+                if is_trial:
+                    return
+
+            # Else if there is no room in the MAU bucket, bail
+            current_mau = yield self.store.get_monthly_active_count()
+            if current_mau >= self.hs.config.max_mau_value:
+                raise ResourceLimitError(
+                    403, "Monthly Active User Limit Exceeded",
+                    admin_uri=self.hs.config.admin_uri,
+                    errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
+                    limit_type="monthly_active_user"
+                )
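
A sketch of how a caller would be expected to use the new check_auth_blocking
guard. Only check_auth_blocking and ResourceLimitError come from this commit;
the handler class and its wiring are hypothetical.

    from twisted.internet import defer

    from synapse.api.errors import ResourceLimitError

    class ExampleHandler(object):
        """Hypothetical handler showing the guard-then-act pattern."""

        def __init__(self, hs):
            self.auth = hs.get_auth()

        @defer.inlineCallbacks
        def do_something_for(self, user_id):
            try:
                # Raises ResourceLimitError if the HS is disabled or the
                # monthly-active-user cap has been reached.
                yield self.auth.check_auth_blocking(user_id)
            except ResourceLimitError as e:
                # e.admin_uri / e.limit_type explain why the user was blocked.
                defer.returnValue((403, e.error_dict()))
            defer.returnValue((200, {}))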

synapse/api/constants.py

@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -77,6 +78,7 @@ class EventTypes(object):
     Name = "m.room.name"
     ServerACL = "m.room.server_acl"
+    Pinned = "m.room.pinned_events"
 
 
 class RejectedReason(object):
@@ -94,3 +96,19 @@ class RoomCreationPreset(object):
 class ThirdPartyEntityKind(object):
     USER = "user"
     LOCATION = "location"
+
+
+class RoomVersions(object):
+    V1 = "1"
+    VDH_TEST = "vdh-test-version"
+
+
+# the version we will give rooms which are created on this server
+DEFAULT_ROOM_VERSION = RoomVersions.V1
+
+# vdh-test-version is a placeholder to get room versioning support working and tested
+# until we have a working v2.
+KNOWN_ROOM_VERSIONS = {RoomVersions.V1, RoomVersions.VDH_TEST}
+
+ServerNoticeMsgType = "m.server_notice"
+ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"

synapse/api/errors.py

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -55,7 +56,9 @@ class Codes(object):
     SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
     CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
     CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
-    MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
+    RESOURCE_LIMIT_EXCEEDED = "M_RESOURCE_LIMIT_EXCEEDED"
+    UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION"
+    INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
     WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
@@ -229,6 +232,30 @@ class AuthError(SynapseError):
         super(AuthError, self).__init__(*args, **kwargs)
 
 
+class ResourceLimitError(SynapseError):
+    """
+    Any error raised when there is a problem with resource usage.
+    For instance, the monthly active user limit for the server has been exceeded
+    """
+    def __init__(
+        self, code, msg,
+        errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
+        admin_uri=None,
+        limit_type=None,
+    ):
+        self.admin_uri = admin_uri
+        self.limit_type = limit_type
+        super(ResourceLimitError, self).__init__(code, msg, errcode=errcode)
+
+    def error_dict(self):
+        return cs_error(
+            self.msg,
+            self.errcode,
+            admin_uri=self.admin_uri,
+            limit_type=self.limit_type
+        )
+
+
 class EventSizeError(SynapseError):
     """An error raised when an event is too big."""
@@ -299,11 +326,24 @@ class RoomKeysVersionError(SynapseError):
         )
         self.current_version = current_version
 
+
+class IncompatibleRoomVersionError(SynapseError):
+    """A server is trying to join a room whose version it does not support."""
+
+    def __init__(self, room_version):
+        super(IncompatibleRoomVersionError, self).__init__(
+            code=400,
+            msg="Your homeserver does not support the features required to "
+                "join this room",
+            errcode=Codes.INCOMPATIBLE_ROOM_VERSION,
+        )
+
+        self._room_version = room_version
+
     def error_dict(self):
         return cs_error(
             self.msg,
             self.errcode,
-            current_version=self.current_version,
+            room_version=self.current_version,
         )
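
An illustrative look (not from the commit) at what the new error type
serializes to; the admin_uri value here is a made-up example.

    from synapse.api.errors import Codes, ResourceLimitError

    err = ResourceLimitError(
        403, "Monthly Active User Limit Exceeded",
        errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
        admin_uri="mailto:admin@example.com",  # hypothetical admin contact
        limit_type="monthly_active_user",
    )

    print(err.error_dict())
    # -> {'error': 'Monthly Active User Limit Exceeded',
    #     'errcode': 'M_RESOURCE_LIMIT_EXCEEDED',
    #     'admin_uri': 'mailto:admin@example.com',
    #     'limit_type': 'monthly_active_user'}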

synapse/api/ratelimiting.py

@@ -72,7 +72,7 @@ class Ratelimiter(object):
         return allowed, time_allowed
 
     def prune_message_counts(self, time_now_s):
-        for user_id in self.message_counts.keys():
+        for user_id in list(self.message_counts.keys()):
             message_count, time_start, msg_rate_hz = (
                 self.message_counts[user_id]
             )
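
The list() copy matters on Python 3, where dict.keys() is a live view; a
standalone illustration (not from the commit):

    counts = {"@a:hs": 1, "@b:hs": 0}

    try:
        for user_id in counts.keys():
            if counts[user_id] == 0:
                del counts[user_id]  # mutates the dict mid-iteration
    except RuntimeError as e:
        print("live view:", e)  # dictionary changed size during iteration

    for user_id in list(counts.keys()):  # snapshot; safe to delete beneath it
        if counts[user_id] == 0:
            del counts[user_id]
    print(counts)  # {'@a:hs': 1}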

synapse/app/_base.py

@@ -140,7 +140,7 @@ def listen_metrics(bind_addresses, port):
     logger.info("Metrics now reporting on %s:%d", host, port)
 
 
-def listen_tcp(bind_addresses, port, factory, backlog=50):
+def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
     """
     Create a TCP socket for a port and several addresses
     """
@@ -156,7 +156,9 @@ def listen_tcp(bind_addresses, port, factory, backlog=50):
             check_bind_error(e, address, bind_addresses)
 
 
-def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
+def listen_ssl(
+    bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
+):
     """
     Create an SSL socket for a port and several addresses
     """

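Threading the reactor through as a parameter makes these helpers testable
against twisted's in-memory reactor. A sketch under the assumption that
listen_tcp keeps the signature shown above (the module path is assumed):

    from twisted.internet.protocol import Factory
    from twisted.test.proto_helpers import MemoryReactor

    from synapse.app._base import listen_tcp  # module path assumed

    reactor = MemoryReactor()
    listen_tcp(["127.0.0.1"], 8008, Factory(), reactor=reactor)

    # MemoryReactor records the bind instead of performing it.
    print(reactor.tcpServers[0][:2])  # (8008, <Factory instance>)
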
synapse/app/appservice.py

@@ -117,8 +117,9 @@ class ASReplicationHandler(ReplicationClientHandler):
         super(ASReplicationHandler, self).__init__(hs.get_datastore())
         self.appservice_handler = hs.get_application_service_handler()
 
+    @defer.inlineCallbacks
     def on_rdata(self, stream_name, token, rows):
-        super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
+        yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
 
         if stream_name == "events":
             max_stream_id = self.store.get_room_max_stream_ordering()
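
The same change is applied to every ReplicationClientHandler subclass in this
commit: once the base on_rdata returns a Deferred, each override must become
an inlineCallbacks generator and yield the superclass call, or the base
class's processing is silently dropped. A generic sketch (class names
hypothetical):

    from twisted.internet import defer

    class BaseHandler(object):
        @defer.inlineCallbacks
        def on_rdata(self, stream_name, token, rows):
            yield defer.succeed(None)  # stand-in for async bookkeeping

    class WorkerHandler(BaseHandler):
        @defer.inlineCallbacks
        def on_rdata(self, stream_name, token, rows):
            # Wait for the base class before acting on the new rows.
            yield super(WorkerHandler, self).on_rdata(stream_name, token, rows)
            print("processed", stream_name, token)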

synapse/app/client_reader.py

@@ -39,7 +39,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.keys import SlavedKeyStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.rest.client.v1.room import (
     JoinedRoomMemberListRestServlet,
@@ -66,7 +66,7 @@ class ClientReaderSlavedStore(
     DirectoryStore,
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
-    TransactionStore,
+    SlavedTransactionStore,
     SlavedClientIpStore,
     BaseSlavedStore,
 ):
@@ -168,11 +168,13 @@ def start(config_options):
     database_engine = create_engine(config.database_config)
 
     tls_server_context_factory = context_factory.ServerContextFactory(config)
+    tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
 
     ss = ClientReaderServer(
         config.server_name,
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
+        tls_client_options_factory=tls_client_options_factory,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,

synapse/app/event_creator.py

@@ -43,8 +43,13 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v1.profile import (
+    ProfileAvatarURLRestServlet,
+    ProfileDisplaynameRestServlet,
+    ProfileRestServlet,
+)
 from synapse.rest.client.v1.room import (
     JoinRoomAliasServlet,
     RoomMembershipRestServlet,
@@ -53,6 +58,7 @@ from synapse.rest.client.v1.room import (
 )
 from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
+from synapse.storage.user_directory import UserDirectoryStore
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext
 from synapse.util.manhole import manhole
@@ -62,8 +68,11 @@ logger = logging.getLogger("synapse.app.event_creator")
 
 class EventCreatorSlavedStore(
+    # FIXME(#3714): We need to add UserDirectoryStore as we write directly
+    # rather than going via the correct worker.
+    UserDirectoryStore,
     DirectoryStore,
-    TransactionStore,
+    SlavedTransactionStore,
     SlavedProfileStore,
     SlavedAccountDataStore,
     SlavedPusherStore,
@@ -101,6 +110,9 @@ class EventCreatorServer(HomeServer):
             RoomMembershipRestServlet(self).register(resource)
             RoomStateEventRestServlet(self).register(resource)
             JoinRoomAliasServlet(self).register(resource)
+            ProfileAvatarURLRestServlet(self).register(resource)
+            ProfileDisplaynameRestServlet(self).register(resource)
+            ProfileRestServlet(self).register(resource)
             resources.update({
                 "/_matrix/client/r0": resource,
                 "/_matrix/client/unstable": resource,
@@ -174,11 +186,13 @@ def start(config_options):
     database_engine = create_engine(config.database_config)
 
     tls_server_context_factory = context_factory.ServerContextFactory(config)
+    tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
 
     ss = EventCreatorServer(
         config.server_name,
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
+        tls_client_options_factory=tls_client_options_factory,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,

synapse/app/federation_reader.py

@@ -32,11 +32,17 @@ from synapse.http.site import SynapseSite
 from synapse.metrics import RegistryProxy
 from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.keys import SlavedKeyStore
+from synapse.replication.slave.storage.profile import SlavedProfileStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.pushers import SlavedPusherStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
@@ -49,11 +55,17 @@ logger = logging.getLogger("synapse.app.federation_reader")
 
 class FederationReaderSlavedStore(
+    SlavedAccountDataStore,
+    SlavedProfileStore,
+    SlavedApplicationServiceStore,
+    SlavedPusherStore,
+    SlavedPushRuleStore,
+    SlavedReceiptsStore,
     SlavedEventStore,
     SlavedKeyStore,
     RoomStore,
     DirectoryStore,
-    TransactionStore,
+    SlavedTransactionStore,
     BaseSlavedStore,
 ):
     pass
@@ -143,11 +155,13 @@ def start(config_options):
     database_engine = create_engine(config.database_config)
 
     tls_server_context_factory = context_factory.ServerContextFactory(config)
+    tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
 
     ss = FederationReaderServer(
         config.server_name,
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
+        tls_client_options_factory=tls_client_options_factory,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,

synapse/app/federation_sender.py

@@ -36,11 +36,11 @@ from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.presence import SlavedPresenceStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
@@ -50,7 +50,7 @@ logger = logging.getLogger("synapse.app.federation_sender")
 
 class FederationSenderSlaveStore(
-    SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
+    SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore,
     SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
 ):
     def __init__(self, db_conn, hs):
@@ -144,8 +144,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
         super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
         self.send_handler = FederationSenderHandler(hs, self)
 
+    @defer.inlineCallbacks
     def on_rdata(self, stream_name, token, rows):
-        super(FederationSenderReplicationHandler, self).on_rdata(
+        yield super(FederationSenderReplicationHandler, self).on_rdata(
             stream_name, token, rows
         )
         self.send_handler.process_replication_rows(stream_name, token, rows)
@@ -186,11 +187,13 @@ def start(config_options):
     config.send_federation = True
 
     tls_server_context_factory = context_factory.ServerContextFactory(config)
+    tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
 
     ps = FederationSenderServer(
         config.server_name,
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
+        tls_client_options_factory=tls_client_options_factory,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,

synapse/app/frontend_proxy.py

@@ -38,6 +38,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns
 from synapse.rest.client.v2_alpha._base import client_v2_patterns
 from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
@@ -49,6 +50,35 @@ from synapse.util.versionstring import get_version_string
 logger = logging.getLogger("synapse.app.frontend_proxy")
 
 
+class PresenceStatusStubServlet(ClientV1RestServlet):
+    PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
+
+    def __init__(self, hs):
+        super(PresenceStatusStubServlet, self).__init__(hs)
+        self.http_client = hs.get_simple_http_client()
+        self.auth = hs.get_auth()
+        self.main_uri = hs.config.worker_main_http_uri
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id):
+        # Pass through the auth headers, if any, in case the access token
+        # is there.
+        auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
+        headers = {
+            "Authorization": auth_headers,
+        }
+        result = yield self.http_client.get_json(
+            self.main_uri + request.uri,
+            headers=headers,
+        )
+        defer.returnValue((200, result))
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, user_id):
+        yield self.auth.get_user_by_req(request)
+        defer.returnValue((200, {}))
+
+
 class KeyUploadServlet(RestServlet):
     PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
@@ -135,6 +165,12 @@ class FrontendProxyServer(HomeServer):
             elif name == "client":
                 resource = JsonResource(self, canonical_json=False)
                 KeyUploadServlet(self).register(resource)
+
+                # If presence is disabled, use the stub servlet that does
+                # not allow sending presence
+                if not self.config.use_presence:
+                    PresenceStatusStubServlet(self).register(resource)
+
                 resources.update({
                     "/_matrix/client/r0": resource,
                     "/_matrix/client/unstable": resource,
@@ -153,7 +189,8 @@ class FrontendProxyServer(HomeServer):
                     listener_config,
                     root_resource,
                     self.version_string,
-                )
+                ),
+                reactor=self.get_reactor()
             )
 
             logger.info("Synapse client reader now listening on port %d", port)
@@ -208,11 +245,13 @@ def start(config_options):
     database_engine = create_engine(config.database_config)
 
     tls_server_context_factory = context_factory.ServerContextFactory(config)
+    tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
 
     ss = FrontendProxyServer(
         config.server_name,
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
+        tls_client_options_factory=tls_client_options_factory,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,

synapse/app/homeserver.py

@@ -303,8 +303,8 @@
 # Gauges to expose monthly active user control metrics
 
-current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
-max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
+current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
+max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
 
 
 def setup(config_options):
@@ -338,6 +338,7 @@ def setup(config_options):
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
     tls_server_context_factory = context_factory.ServerContextFactory(config)
+    tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
 
     database_engine = create_engine(config.database_config)
     config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
@@ -346,6 +347,7 @@ def setup(config_options):
         config.server_name,
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
+        tls_client_options_factory=tls_client_options_factory,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,
@@ -519,17 +521,27 @@ def run(hs):
         # table will decrease
         clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
 
+        # monthly active user limiting functionality
+        clock.looping_call(
+            hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60
+        )
+        hs.get_datastore().reap_monthly_active_users()
+
         @defer.inlineCallbacks
         def generate_monthly_active_users():
             count = 0
             if hs.config.limit_usage_by_mau:
-                count = yield hs.get_datastore().count_monthly_users()
+                count = yield hs.get_datastore().get_monthly_active_count()
             current_mau_gauge.set(float(count))
-            max_mau_value_gauge.set(float(hs.config.max_mau_value))
+            max_mau_gauge.set(float(hs.config.max_mau_value))
 
+        hs.get_datastore().initialise_reserved_users(
+            hs.config.mau_limits_reserved_threepids
+        )
         generate_monthly_active_users()
         if hs.config.limit_usage_by_mau:
             clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
+        # End of monthly active user settings
 
         if hs.config.report_stats:
             logger.info("Scheduling stats reporting for 3 hour intervals")
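
The renamed gauges are ordinary prometheus_client Gauges; a standalone sketch
(the value is illustrative) of the register-once, set-periodically pattern
used here:

    from prometheus_client import Gauge, generate_latest

    # Registered once at module level, as in the diff above.
    current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")

    # A scheduled call updates it with the latest count.
    current_mau_gauge.set(42.0)

    print(generate_latest().decode("ascii"))
    # ...synapse_admin_mau:current 42.0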

synapse/app/media_repository.py

@@ -34,7 +34,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.rest.media.v0.content_repository import ContentRepoResource
 from synapse.server import HomeServer
@@ -52,7 +52,7 @@ class MediaRepositorySlavedStore(
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
     SlavedClientIpStore,
-    TransactionStore,
+    SlavedTransactionStore,
     BaseSlavedStore,
     MediaRepositoryStore,
 ):
@@ -155,11 +155,13 @@ def start(config_options):
     database_engine = create_engine(config.database_config)
 
     tls_server_context_factory = context_factory.ServerContextFactory(config)
+    tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
 
     ss = MediaRepositoryServer(
         config.server_name,
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
+        tls_client_options_factory=tls_client_options_factory,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,

synapse/app/pusher.py

@@ -148,8 +148,9 @@ class PusherReplicationHandler(ReplicationClientHandler):
 
         self.pusher_pool = hs.get_pusherpool()
 
+    @defer.inlineCallbacks
     def on_rdata(self, stream_name, token, rows):
-        super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
+        yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
         run_in_background(self.poke_pushers, stream_name, token, rows)
 
     @defer.inlineCallbacks
@@ -162,11 +163,11 @@ class PusherReplicationHandler(ReplicationClientHandler):
                     else:
                         yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
             elif stream_name == "events":
-                yield self.pusher_pool.on_new_notifications(
+                self.pusher_pool.on_new_notifications(
                     token, token,
                 )
             elif stream_name == "receipts":
-                yield self.pusher_pool.on_new_receipts(
+                self.pusher_pool.on_new_receipts(
                     token, token, set(row.room_id for row in rows)
                 )
         except Exception:

synapse/app/synchrotron.py

@@ -114,7 +114,10 @@ class SynchrotronPresence(object):
         logger.info("Presence process_id is %r", self.process_id)
 
     def send_user_sync(self, user_id, is_syncing, last_sync_ms):
-        self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms)
+        if self.hs.config.use_presence:
+            self.hs.get_tcp_replication().send_user_sync(
+                user_id, is_syncing, last_sync_ms
+            )
 
     def mark_as_coming_online(self, user_id):
         """A user has started syncing. Send a UserSync to the master, unless they
@@ -211,10 +214,13 @@ class SynchrotronPresence(object):
         yield self.notify_from_replication(states, stream_id)
 
     def get_currently_syncing_users(self):
-        return [
-            user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
-            if count > 0
-        ]
+        if self.hs.config.use_presence:
+            return [
+                user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
+                if count > 0
+            ]
+        else:
+            return set()
 
 
 class SynchrotronTyping(object):
@@ -332,8 +338,9 @@ class SyncReplicationHandler(ReplicationClientHandler):
         self.presence_handler = hs.get_presence_handler()
         self.notifier = hs.get_notifier()
 
+    @defer.inlineCallbacks
     def on_rdata(self, stream_name, token, rows):
-        super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
+        yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
         run_in_background(self.process_and_notify, stream_name, token, rows)
 
     def get_streams_to_replicate(self):

synapse/app/user_directory.py

@@ -169,8 +169,9 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
         super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
         self.user_directory = hs.get_user_directory_handler()
 
+    @defer.inlineCallbacks
     def on_rdata(self, stream_name, token, rows):
-        super(UserDirectoryReplicationHandler, self).on_rdata(
+        yield super(UserDirectoryReplicationHandler, self).on_rdata(
             stream_name, token, rows
         )
         if stream_name == "current_state_deltas":
@@ -214,11 +215,13 @@ def start(config_options):
     config.update_user_directory = True
 
     tls_server_context_factory = context_factory.ServerContextFactory(config)
+    tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
 
     ps = UserDirectoryServer(
         config.server_name,
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
+        tls_client_options_factory=tls_client_options_factory,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,


@@ -168,7 +168,8 @@ def setup_logging(config, use_worker_options=False):
         if log_file:
             # TODO: Customisable file size / backup count
             handler = logging.handlers.RotatingFileHandler(
-                log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
+                log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
+                encoding='utf8'
             )

         def sighup(signum, stack):

@@ -193,9 +194,8 @@ def setup_logging(config, use_worker_options=False):
         def sighup(signum, stack):
             # it might be better to use a file watcher or something for this.
-            logging.info("Reloading log config from %s due to SIGHUP",
-                         log_config)
             load_log_config()
+            logging.info("Reloaded log config from %s due to SIGHUP", log_config)

         load_log_config()


@@ -49,6 +49,9 @@ class ServerConfig(Config):
         # "disable" federation
         self.send_federation = config.get("send_federation", True)

+        # Whether to enable user presence.
+        self.use_presence = config.get("use_presence", True)
+
         # Whether to update the user directory or not. This should be set to
         # false only if we are updating the user directory in a worker
         self.update_user_directory = config.get("update_user_directory", True)

@@ -69,12 +72,29 @@ class ServerConfig(Config):
         # Options to control access by tracking MAU
         self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
+        self.max_mau_value = 0
         if self.limit_usage_by_mau:
             self.max_mau_value = config.get(
                 "max_mau_value", 0,
             )
-        else:
-            self.max_mau_value = 0
+
+        self.mau_limits_reserved_threepids = config.get(
+            "mau_limit_reserved_threepids", []
+        )
+
+        self.mau_trial_days = config.get(
+            "mau_trial_days", 0,
+        )
+
+        # Options to disable HS
+        self.hs_disabled = config.get("hs_disabled", False)
+        self.hs_disabled_message = config.get("hs_disabled_message", "")
+        self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "")
+
+        # Admin uri to direct users at should their instance become blocked
+        # due to resource constraints
+        self.admin_uri = config.get("admin_uri", None)

         # FIXME: federation_domain_whitelist needs sytests
         self.federation_domain_whitelist = None
         federation_domain_whitelist = config.get(

@@ -238,6 +258,9 @@ class ServerConfig(Config):
         # hard limit.
         soft_file_limit: 0

+        # Set to false to disable presence tracking on this homeserver.
+        use_presence: true
+
         # The GC threshold parameters to pass to `gc.set_threshold`, if defined
         # gc_thresholds: [700, 10, 10]

@@ -329,6 +352,33 @@ class ServerConfig(Config):
         #  - port: 9000
         #    bind_addresses: ['::1', '127.0.0.1']
         #    type: manhole

+        # Homeserver blocking
+        #
+        # How to reach the server admin, used in ResourceLimitError
+        # admin_uri: 'mailto:admin@server.com'
+        #
+        # Global block config
+        #
+        # hs_disabled: False
+        # hs_disabled_message: 'Human readable reason for why the HS is blocked'
+        # hs_disabled_limit_type: 'error code(str), to help clients decode reason'
+        #
+        # Monthly Active User Blocking
+        #
+        # Enables monthly active user checking
+        # limit_usage_by_mau: False
+        # max_mau_value: 50
+        # mau_trial_days: 2
+        #
+        # Sometimes the server admin will want to ensure certain accounts are
+        # never blocked by mau checking. These accounts are specified here.
+        #
+        # mau_limit_reserved_threepids:
+        #     - medium: 'email'
+        #       address: 'reserved_user@example.com'
         """ % locals()

     def read_arguments(self, args):

@@ -11,19 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import logging

+from zope.interface import implementer
+
 from OpenSSL import SSL, crypto
-from twisted.internet import ssl
 from twisted.internet._sslverify import _defaultCurveName
+from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
+from twisted.internet.ssl import CertificateOptions, ContextFactory
+from twisted.python.failure import Failure

 logger = logging.getLogger(__name__)


-class ServerContextFactory(ssl.ContextFactory):
+class ServerContextFactory(ContextFactory):
     """Factory for PyOpenSSL SSL contexts that are used to handle incoming
-    connections and to make connections to remote servers."""
+    connections."""

     def __init__(self, config):
         self._context = SSL.Context(SSL.SSLv23_METHOD)

@@ -48,3 +51,78 @@ class ServerContextFactory(ssl.ContextFactory):
     def getContext(self):
         return self._context

+
+def _idnaBytes(text):
+    """
+    Convert some text typed by a human into some ASCII bytes. This is a
+    copy of twisted.internet._idna._idnaBytes. For documentation, see the
+    twisted documentation.
+    """
+    try:
+        import idna
+    except ImportError:
+        return text.encode("idna")
+    else:
+        return idna.encode(text)
+
+
+def _tolerateErrors(wrapped):
+    """
+    Wrap up an info_callback for pyOpenSSL so that if something goes wrong
+    the error is immediately logged and the connection is dropped if possible.
+    This is a copy of twisted.internet._sslverify._tolerateErrors. For
+    documentation, see the twisted documentation.
+    """
+    def infoCallback(connection, where, ret):
+        try:
+            return wrapped(connection, where, ret)
+        except:  # noqa: E722, taken from the twisted implementation
+            f = Failure()
+            logger.exception("Error during info_callback")
+            connection.get_app_data().failVerification(f)
+    return infoCallback
+
+
+@implementer(IOpenSSLClientConnectionCreator)
+class ClientTLSOptions(object):
+    """
+    Client creator for TLS without certificate identity verification. This is a
+    copy of twisted.internet._sslverify.ClientTLSOptions with the identity
+    verification left out. For documentation, see the twisted documentation.
+    """
+
+    def __init__(self, hostname, ctx):
+        self._ctx = ctx
+        self._hostname = hostname
+        self._hostnameBytes = _idnaBytes(hostname)
+        ctx.set_info_callback(
+            _tolerateErrors(self._identityVerifyingInfoCallback)
+        )
+
+    def clientConnectionForTLS(self, tlsProtocol):
+        context = self._ctx
+        connection = SSL.Connection(context, None)
+        connection.set_app_data(tlsProtocol)
+        return connection
+
+    def _identityVerifyingInfoCallback(self, connection, where, ret):
+        if where & SSL.SSL_CB_HANDSHAKE_START:
+            connection.set_tlsext_host_name(self._hostnameBytes)
+
+
+class ClientTLSOptionsFactory(object):
+    """Factory for Twisted ClientTLSOptions that are used to make connections
+    to remote servers for federation."""
+
+    def __init__(self, config):
+        # We don't use config options yet
+        pass
+
+    def get_options(self, host):
+        return ClientTLSOptions(
+            host.decode('utf-8'),
+            CertificateOptions(verify=False).getContext()
+        )
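
A hypothetical usage sketch of the new factory: ClientTLSOptionsFactory and get_options come from the code above, while the endpoint wiring is an assumption about a caller (wrapClientTLS is standard Twisted):

    from twisted.internet import reactor
    from twisted.internet.endpoints import TCP4ClientEndpoint, wrapClientTLS

    factory = ClientTLSOptionsFactory(config=None)  # the config argument is unused for now
    tls_options = factory.get_options(b"remote.example.com")  # SNI will be set to this host
    endpoint = wrapClientTLS(
        tls_options,
        TCP4ClientEndpoint(reactor, "remote.example.com", 8448),
    )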


@@ -18,7 +18,9 @@ import logging
 from canonicaljson import json

 from twisted.internet import defer, reactor
+from twisted.internet.error import ConnectError
 from twisted.internet.protocol import Factory
+from twisted.names.error import DomainError
 from twisted.web.http import HTTPClient

 from synapse.http.endpoint import matrix_federation_endpoint

@@ -30,14 +32,14 @@ KEY_API_V1 = b"/_matrix/key/v1/"

 @defer.inlineCallbacks
-def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
+def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
     """Fetch the keys for a remote server."""

     factory = SynapseKeyClientFactory()
     factory.path = path
     factory.host = server_name
     endpoint = matrix_federation_endpoint(
-        reactor, server_name, ssl_context_factory, timeout=30
+        reactor, server_name, tls_client_options_factory, timeout=30
     )

     for i in range(5):

@@ -47,12 +49,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
             server_response, server_certificate = yield protocol.remote_key
             defer.returnValue((server_response, server_certificate))
         except SynapseKeyClientError as e:
-            logger.exception("Error getting key for %r" % (server_name,))
+            logger.warn("Error getting key for %r: %s", server_name, e)
             if e.status.startswith("4"):
                 # Don't retry for 4xx responses.
                 raise IOError("Cannot get key for %r" % server_name)
+        except (ConnectError, DomainError) as e:
+            logger.warn("Error getting key for %r: %s", server_name, e)
         except Exception as e:
-            logger.exception(e)
+            logger.exception("Error getting key for %r", server_name)
     raise IOError("Cannot get key for %r" % server_name)
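
The retry policy above in isolation, as a runnable sketch: up to five attempts, but a 4xx status from the remote aborts immediately, since retrying a client error is pointless:

    def should_retry(status):
        # SynapseKeyClientError statuses are strings, per the check above.
        return not status.startswith("4")

    assert should_retry("500") is True   # server errors are retried
    assert should_retry("404") is False  # client errors are not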


@@ -512,7 +512,7 @@ class Keyring(object):
                 continue

             (response, tls_certificate) = yield fetch_server_key(
-                server_name, self.hs.tls_server_context_factory,
+                server_name, self.hs.tls_client_options_factory,
                 path=(b"/_matrix/key/v2/server/%s" % (
                     urllib.quote(requested_key_id),
                 )).encode("ascii"),

@@ -655,7 +655,7 @@ class Keyring(object):
         # Try to fetch the key from the remote server.

         (response, tls_certificate) = yield fetch_server_key(
-            server_name, self.hs.tls_server_context_factory
+            server_name, self.hs.tls_client_options_factory
         )

         # Check the response.


@@ -20,7 +20,7 @@ from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import SignatureVerifyException, verify_signed_json
 from unpaddedbase64 import decode_base64

-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, JoinRules, Membership
 from synapse.api.errors import AuthError, EventSizeError, SynapseError
 from synapse.types import UserID, get_domain_from_id

@@ -83,6 +83,14 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
                 403,
                 "Creation event's room_id domain does not match sender's"
             )

+        room_version = event.content.get("room_version", "1")
+        if room_version not in KNOWN_ROOM_VERSIONS:
+            raise AuthError(
+                403,
+                "room appears to have unsupported version %s" % (
+                    room_version,
+                ))
+
         # FIXME
         logger.debug("Allowing! %s", event)
         return
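
Note the default here: a create event with no room_version is treated as version "1". A minimal illustration, using a stand-in for KNOWN_ROOM_VERSIONS (the real set lives in synapse.api.constants):

    KNOWN_ROOM_VERSIONS = {"1"}  # stand-in for illustration

    def room_version_ok(create_content):
        return create_content.get("room_version", "1") in KNOWN_ROOM_VERSIONS

    assert room_version_ok({}) is True                       # absent defaults to "1"
    assert room_version_ok({"room_version": "99"}) is False  # would raise AuthError(403)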


@@ -25,7 +25,7 @@ from prometheus_client import Counter

 from twisted.internet import defer

-from synapse.api.constants import Membership
+from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership
 from synapse.api.errors import (
     CodeMessageException,
     FederationDeniedError,

@@ -518,10 +518,10 @@ class FederationClient(FederationBase):
                     description, destination, exc_info=1,
                 )

-        raise RuntimeError("Failed to %s via any server", description)
+        raise RuntimeError("Failed to %s via any server" % (description, ))

     def make_membership_event(self, destinations, room_id, user_id, membership,
-                              content={},):
+                              content, params):
         """
         Creates an m.room.member event, with context, without participating in the room.

@@ -537,8 +537,10 @@ class FederationClient(FederationBase):
             user_id (str): The user whose membership is being evented.
             membership (str): The "membership" property of the event. Must be
                 one of "join" or "leave".
-            content (object): Any additional data to put into the content field
+            content (dict): Any additional data to put into the content field
                 of the event.
+            params (dict[str, str|Iterable[str]]): Query parameters to include in the
+                request.

         Return:
             Deferred: resolves to a tuple of (origin (str), event (object))
             where origin is the remote homeserver which generated the event.

@@ -558,10 +560,12 @@ class FederationClient(FederationBase):
         @defer.inlineCallbacks
         def send_request(destination):
             ret = yield self.transport_layer.make_membership_event(
-                destination, room_id, user_id, membership
+                destination, room_id, user_id, membership, params,
             )

-            pdu_dict = ret["event"]
+            pdu_dict = ret.get("event", None)
+            if not isinstance(pdu_dict, dict):
+                raise InvalidResponseError("Bad 'event' field in response")

             logger.debug("Got response to make_%s: %s", membership, pdu_dict)

@@ -605,6 +609,26 @@ class FederationClient(FederationBase):
             Fails with a ``RuntimeError`` if no servers were reachable.
         """

+        def check_authchain_validity(signed_auth_chain):
+            for e in signed_auth_chain:
+                if e.type == EventTypes.Create:
+                    create_event = e
+                    break
+            else:
+                raise InvalidResponseError(
+                    "no %s in auth chain" % (EventTypes.Create,),
+                )
+
+            # the room version should be sane.
+            room_version = create_event.content.get("room_version", "1")
+            if room_version not in KNOWN_ROOM_VERSIONS:
+                # This shouldn't be possible, because the remote server should have
+                # rejected the join attempt during make_join.
+                raise InvalidResponseError(
+                    "room appears to have unsupported version %s" % (
+                        room_version,
+                    ))
+
         @defer.inlineCallbacks
         def send_request(destination):
             time_now = self._clock.time_msec()

@@ -661,7 +685,7 @@ class FederationClient(FederationBase):
             for s in signed_state:
                 s.internal_metadata = copy.deepcopy(s.internal_metadata)

-            auth_chain.sort(key=lambda e: e.depth)
+            check_authchain_validity(signed_auth)

             defer.returnValue({
                 "state": signed_state,


@@ -27,14 +27,24 @@ from twisted.internet.abstract import isIPAddress
 from twisted.python import failure

 from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, FederationError, NotFoundError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    FederationError,
+    IncompatibleRoomVersionError,
+    NotFoundError,
+    SynapseError,
+)
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
 from synapse.http.endpoint import parse_server_name
+from synapse.replication.http.federation import (
+    ReplicationFederationSendEduRestServlet,
+    ReplicationGetQueryRestServlet,
+)
 from synapse.types import get_domain_from_id
-from synapse.util import async
+from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.logutils import log_function

@@ -61,8 +71,8 @@ class FederationServer(FederationBase):
         self.auth = hs.get_auth()
         self.handler = hs.get_handlers().federation_handler

-        self._server_linearizer = async.Linearizer("fed_server")
-        self._transaction_linearizer = async.Linearizer("fed_txn_handler")
+        self._server_linearizer = Linearizer("fed_server")
+        self._transaction_linearizer = Linearizer("fed_txn_handler")

         self.transaction_actions = TransactionActions(self.store)

@@ -194,7 +204,7 @@ class FederationServer(FederationBase):
                     event_id, f.getTraceback().rstrip(),
                 )

-        yield async.concurrently_execute(
+        yield concurrently_execute(
             process_pdus_for_room, pdus_by_room.keys(),
             TRANSACTION_CONCURRENCY_LIMIT,
         )

@@ -323,12 +333,21 @@ class FederationServer(FederationBase):
         defer.returnValue((200, resp))

     @defer.inlineCallbacks
-    def on_make_join_request(self, origin, room_id, user_id):
+    def on_make_join_request(self, origin, room_id, user_id, supported_versions):
         origin_host, _ = parse_server_name(origin)
         yield self.check_server_matches_acl(origin_host, room_id)

+        room_version = yield self.store.get_room_version(room_id)
+        if room_version not in supported_versions:
+            logger.warn("Room version %s not in %s", room_version, supported_versions)
+            raise IncompatibleRoomVersionError(room_version=room_version)
+
         pdu = yield self.handler.on_make_join_request(room_id, user_id)
         time_now = self._clock.time_msec()
-        defer.returnValue({"event": pdu.get_pdu_json(time_now)})
+        defer.returnValue({
+            "event": pdu.get_pdu_json(time_now),
+            "room_version": room_version,
+        })

     @defer.inlineCallbacks
     def on_invite_request(self, origin, content):

@@ -745,6 +764,8 @@ class FederationHandlerRegistry(object):
         if edu_type in self.edu_handlers:
             raise KeyError("Already have an EDU handler for %s" % (edu_type,))

+        logger.info("Registering federation EDU handler for %r", edu_type)
+
         self.edu_handlers[edu_type] = handler

     def register_query_handler(self, query_type, handler):

@@ -763,6 +784,8 @@ class FederationHandlerRegistry(object):
                 "Already have a Query handler for %s" % (query_type,)
             )

+        logger.info("Registering federation query handler for %r", query_type)
+
         self.query_handlers[query_type] = handler

     @defer.inlineCallbacks

@@ -785,3 +808,49 @@ class FederationHandlerRegistry(object):
             raise NotFoundError("No handler for Query type '%s'" % (query_type,))

         return handler(args)

+
+class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
+    """A FederationHandlerRegistry for worker processes.
+
+    When receiving EDU or queries it will check if an appropriate handler has
+    been registered on the worker, if there isn't one then it calls off to the
+    master process.
+    """
+
+    def __init__(self, hs):
+        self.config = hs.config
+        self.http_client = hs.get_simple_http_client()
+        self.clock = hs.get_clock()
+
+        self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
+        self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
+
+        super(ReplicationFederationHandlerRegistry, self).__init__()
+
+    def on_edu(self, edu_type, origin, content):
+        """Overrides FederationHandlerRegistry
+        """
+        handler = self.edu_handlers.get(edu_type)
+        if handler:
+            return super(ReplicationFederationHandlerRegistry, self).on_edu(
+                edu_type, origin, content,
+            )
+
+        return self._send_edu(
+            edu_type=edu_type,
+            origin=origin,
+            content=content,
+        )
+
+    def on_query(self, query_type, args):
+        """Overrides FederationHandlerRegistry
+        """
+        handler = self.query_handlers.get(query_type)
+        if handler:
+            return handler(args)
+
+        return self._get_query_client(
+            query_type=query_type,
+            args=args,
+        )
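
A hypothetical dispatch flow for the new registry (the handler names here are invented; m.receipt and m.typing are standard EDU types):

    registry = ReplicationFederationHandlerRegistry(hs)
    registry.register_edu_handler("m.receipt", on_receipt)  # registered on this worker

    registry.on_edu("m.receipt", origin, content)  # handled locally
    registry.on_edu("m.typing", origin, content)   # no local handler: proxied to the master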


@@ -26,6 +26,8 @@ from synapse.api.errors import FederationDeniedError, HttpResponseException
 from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
 from synapse.metrics import (
     LaterGauge,
+    event_processing_loop_counter,
+    event_processing_loop_room_count,
     events_processed_counter,
     sent_edus_counter,
     sent_transactions_counter,

@@ -56,6 +58,7 @@ class TransactionQueue(object):
     """

     def __init__(self, hs):
+        self.hs = hs
         self.server_name = hs.hostname

         self.store = hs.get_datastore()

@@ -253,7 +256,13 @@ class TransactionQueue(object):
                     synapse.metrics.event_processing_last_ts.labels(
                         "federation_sender").set(ts)

                 events_processed_counter.inc(len(events))

+                event_processing_loop_room_count.labels(
+                    "federation_sender"
+                ).inc(len(events_by_room))
+
+                event_processing_loop_counter.labels("federation_sender").inc()
+
                 synapse.metrics.event_processing_positions.labels(
                     "federation_sender").set(next_token)

@@ -300,6 +309,9 @@ class TransactionQueue(object):
         Args:
             states (list(UserPresenceState))
         """
+        if not self.hs.config.use_presence:
+            # No-op if presence is disabled.
+            return

         # First we queue up the new presence by user ID, so multiple presence
         # updates in quick successtion are correctly handled


@@ -106,7 +106,7 @@ class TransportLayerClient(object):
             dest (str)
             room_id (str)
             event_tuples (list)
-            limt (int)
+            limit (int)

         Returns:
             Deferred: Results in a dict received from the remote homeserver.

@@ -195,7 +195,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
-    def make_membership_event(self, destination, room_id, user_id, membership):
+    def make_membership_event(self, destination, room_id, user_id, membership, params):
         """Asks a remote server to build and sign us a membership event

         Note that this does not append any events to any graphs.

@@ -205,6 +205,8 @@ class TransportLayerClient(object):
             room_id (str): room to join/leave
             user_id (str): user to be joined/left
             membership (str): one of join/leave
+            params (dict[str, str|Iterable[str]]): Query parameters to include in the
+                request.

         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result

@@ -241,6 +243,7 @@ class TransportLayerClient(object):
         content = yield self.client.get_json(
             destination=destination,
             path=path,
+            args=params,
             retry_on_dns_fail=retry_on_dns_fail,
             timeout=20000,
             ignore_backoff=ignore_backoff,


@@ -190,6 +190,41 @@ def _parse_auth_header(header_bytes):

 class BaseFederationServlet(object):
+    """Abstract base class for federation servlet classes.
+
+    The servlet object should have a PATH attribute which takes the form of a regexp to
+    match against the request path (excluding the /federation/v1 prefix).
+
+    The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
+    the appropriate HTTP method. These methods have the signature:
+
+        on_<METHOD>(self, origin, content, query, **kwargs)
+
+    With arguments:
+
+        origin (unicode|None): The authenticated server_name of the calling server,
+            unless REQUIRE_AUTH is set to False and authentication failed.
+
+        content (unicode|None): decoded json body of the request. None if the
+            request was a GET.
+
+        query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
+            (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
+            yet.
+
+        **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+            components as specified in the path match regexp.
+
+    Returns:
+        Deferred[(int, object)|None]: either (response code, response object) to
+            return a JSON response, or None if the request has already been handled.
+
+    Raises:
+        SynapseError: to return an error code
+        Exception: other exceptions will be caught, logged, and a 500 will be
+            returned.
+    """
     REQUIRE_AUTH = True

     def __init__(self, handler, authenticator, ratelimiter, server_name):
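
To make the documented contract concrete, a hypothetical servlet following it (the endpoint path and handler method here are invented purely for illustration):

    from twisted.internet import defer

    class ExampleRoomInfoServlet(BaseFederationServlet):
        PATH = "/example_room_info/(?P<room_id>[^/]*)"

        @defer.inlineCallbacks
        def on_GET(self, origin, content, query, room_id):
            # origin is the authenticated caller; content is None for GETs
            info = yield self.handler.get_example_room_info(origin, room_id)
            defer.returnValue((200, info))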
@@ -204,6 +239,18 @@ class BaseFederationServlet(object):
         @defer.inlineCallbacks
         @functools.wraps(func)
         def new_func(request, *args, **kwargs):
+            """ A callback which can be passed to HttpServer.RegisterPaths

+            Args:
+                request (twisted.web.http.Request):
+                *args: unused?
+                **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+                    components as specified in the path match regexp.
+
+            Returns:
+                Deferred[(int, object)|None]: (response code, response object) as returned
+                    by the callback method. None if the request has already been handled.
+            """
             content = None
             if request.method in ["PUT", "POST"]:
                 # TODO: Handle other method types? other content types?

@@ -214,10 +261,10 @@ class BaseFederationServlet(object):
             except NoAuthenticationError:
                 origin = None
                 if self.REQUIRE_AUTH:
-                    logger.exception("authenticate_request failed")
+                    logger.warn("authenticate_request failed: missing authentication")
                     raise
-            except Exception:
-                logger.exception("authenticate_request failed")
+            except Exception as e:
+                logger.warn("authenticate_request failed: %s", e)
                 raise

             if origin:

@@ -384,9 +431,31 @@ class FederationMakeJoinServlet(BaseFederationServlet):
     PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"

     @defer.inlineCallbacks
-    def on_GET(self, origin, content, query, context, user_id):
+    def on_GET(self, origin, _content, query, context, user_id):
+        """
+        Args:
+            origin (unicode): The authenticated server_name of the calling server
+
+            _content (None): (GETs don't have bodies)
+
+            query (dict[bytes, list[bytes]]): Query params from the request.
+
+            **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+                components as specified in the path match regexp.
+
+        Returns:
+            Deferred[(int, object)|None]: either (response code, response object) to
+                return a JSON response, or None if the request has already been handled.
+        """
+        versions = query.get(b'ver')
+        if versions is not None:
+            supported_versions = [v.decode("utf-8") for v in versions]
+        else:
+            supported_versions = ["1"]
+
         content = yield self.handler.on_make_join_request(
             origin, context, user_id,
+            supported_versions=supported_versions,
         )
         defer.returnValue((200, content))
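
On the wire, the joining server advertises the room versions it supports as repeated ver query parameters; a sketch of what the servlet then sees (values illustrative):

    # GET /_matrix/federation/v1/make_join/!room:hs/@user:hs?ver=1&ver=2
    query = {b"ver": [b"1", b"2"]}

    versions = query.get(b"ver")
    supported_versions = [v.decode("utf-8") for v in versions]
    assert supported_versions == ["1", "2"]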


@@ -23,6 +23,10 @@ from twisted.internet import defer

 import synapse
 from synapse.api.constants import EventTypes
+from synapse.metrics import (
+    event_processing_loop_counter,
+    event_processing_loop_room_count,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 from synapse.util.metrics import Measure

@@ -136,6 +140,12 @@ class ApplicationServicesHandler(object):

                 events_processed_counter.inc(len(events))

+                event_processing_loop_room_count.labels(
+                    "appservice_sender"
+                ).inc(len(events_by_room))
+
+                event_processing_loop_counter.labels("appservice_sender").inc()
+
                 synapse.metrics.event_processing_lag.labels(
                     "appservice_sender").set(now - ts)
                 synapse.metrics.event_processing_last_ts.labels(


@@ -520,7 +520,7 @@ class AuthHandler(BaseHandler):
         """
         logger.info("Logging in user %s on device %s", user_id, device_id)
         access_token = yield self.issue_access_token(user_id, device_id)
-        yield self._check_mau_limits()
+        yield self.auth.check_auth_blocking(user_id)

         # the device *should* have been registered before we got here; however,
         # it's possible we raced against a DELETE operation. The thing we

@@ -734,7 +734,6 @@ class AuthHandler(BaseHandler):
     @defer.inlineCallbacks
     def validate_short_term_login_token_and_get_user_id(self, login_token):
-        yield self._check_mau_limits()
         auth_api = self.hs.get_auth()
         user_id = None
         try:

@@ -743,6 +742,7 @@ class AuthHandler(BaseHandler):
             auth_api.validate_macaroon(macaroon, "login", True, user_id)
         except Exception:
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+        yield self.auth.check_auth_blocking(user_id)
         defer.returnValue(user_id)

     @defer.inlineCallbacks

@@ -828,12 +828,26 @@ class AuthHandler(BaseHandler):

     @defer.inlineCallbacks
     def delete_threepid(self, user_id, medium, address):
+        """Attempts to unbind the 3pid on the identity servers and deletes it
+        from the local database.
+
+        Args:
+            user_id (str)
+            medium (str)
+            address (str)
+
+        Returns:
+            Deferred[bool]: Returns True if successfully unbound the 3pid on
+            the identity server, False if identity server doesn't support the
+            unbind API.
+        """
+
         # 'Canonicalise' email addresses as per above
         if medium == 'email':
             address = address.lower()

         identity_handler = self.hs.get_handlers().identity_handler
-        yield identity_handler.unbind_threepid(
+        result = yield identity_handler.try_unbind_threepid(
             user_id,
             {
                 'medium': medium,

@@ -841,10 +855,10 @@ class AuthHandler(BaseHandler):
             },
         )

-        ret = yield self.store.user_delete_threepid(
+        yield self.store.user_delete_threepid(
             user_id, medium, address,
         )
-        defer.returnValue(ret)
+        defer.returnValue(result)

     def _save_session(self, session):
         # TODO: Persistent storage

@@ -907,19 +921,6 @@ class AuthHandler(BaseHandler):
         else:
             return defer.succeed(False)

-    @defer.inlineCallbacks
-    def _check_mau_limits(self):
-        """
-        Ensure that if mau blocking is enabled that invalid users cannot
-        log in.
-        """
-        if self.hs.config.limit_usage_by_mau is True:
-            current_mau = yield self.store.count_monthly_users()
-            if current_mau >= self.hs.config.max_mau_value:
-                raise AuthError(
-                    403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
-                )
-

 @attr.s
 class MacaroonGenerator(object):


@@ -51,7 +51,8 @@ class DeactivateAccountHandler(BaseHandler):
             erase_data (bool): whether to GDPR-erase the user's data

         Returns:
-            Deferred
+            Deferred[bool]: True if identity server supports removing
+            threepids, otherwise False.
         """
         # FIXME: Theoretically there is a race here wherein user resets
         # password using threepid.

@@ -60,16 +61,22 @@ class DeactivateAccountHandler(BaseHandler):
         # leave the user still active so they can try again.
         # Ideally we would prevent password resets and then do this in the
         # background thread.
+
+        # This will be set to false if the identity server doesn't support
+        # unbinding
+        identity_server_supports_unbinding = True
+
         threepids = yield self.store.user_get_threepids(user_id)
         for threepid in threepids:
             try:
-                yield self._identity_handler.unbind_threepid(
+                result = yield self._identity_handler.try_unbind_threepid(
                     user_id,
                     {
                         'medium': threepid['medium'],
                         'address': threepid['address'],
                     },
                 )
+                identity_server_supports_unbinding &= result
             except Exception:
                 # Do we want this to be a fatal error or should we carry on?
                 logger.exception("Failed to remove threepid from ID server")

@@ -103,6 +110,8 @@ class DeactivateAccountHandler(BaseHandler):
         # parts users from rooms (if it isn't already running)
         self._start_user_parting()

+        defer.returnValue(identity_server_supports_unbinding)
+
     def _start_user_parting(self):
         """
         Start the process that goes through the table of users


@@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import FederationDeniedError
 from synapse.types import RoomStreamToken, get_domain_from_id
 from synapse.util import stringutils
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import measure_func
 from synapse.util.retryutils import NotRetryingDestination


@@ -30,7 +30,12 @@ from unpaddedbase64 import decode_base64

 from twisted.internet import defer

-from synapse.api.constants import EventTypes, Membership, RejectedReason
+from synapse.api.constants import (
+    KNOWN_ROOM_VERSIONS,
+    EventTypes,
+    Membership,
+    RejectedReason,
+)
 from synapse.api.errors import (
     AuthError,
     CodeMessageException,

@@ -44,10 +49,15 @@ from synapse.crypto.event_signing import (
     compute_event_signature,
 )
 from synapse.events.validator import EventValidator
+from synapse.replication.http.federation import (
+    ReplicationCleanRoomRestServlet,
+    ReplicationFederationSendEventsRestServlet,
+)
+from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import resolve_events_with_factory
 from synapse.types import UserID, get_domain_from_id
 from synapse.util import logcontext, unwrapFirstError
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room
 from synapse.util.frozenutils import unfreeze
 from synapse.util.logutils import log_function

@@ -86,6 +96,18 @@ class FederationHandler(BaseHandler):
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
+        self.config = hs.config
+        self.http_client = hs.get_simple_http_client()
+
+        self._send_events_to_master = (
+            ReplicationFederationSendEventsRestServlet.make_client(hs)
+        )
+        self._notify_user_membership_change = (
+            ReplicationUserJoinedLeftRoomRestServlet.make_client(hs)
+        )
+        self._clean_room_for_join_client = (
+            ReplicationCleanRoomRestServlet.make_client(hs)
+        )

         # When joining a room we need to queue any events for that room up
         self.room_queues = {}
@@ -269,8 +291,9 @@ class FederationHandler(BaseHandler):
                     ev_ids, get_prev_content=False, check_redacted=False
                 )

+            room_version = yield self.store.get_room_version(pdu.room_id)
             state_map = yield resolve_events_with_factory(
-                state_groups, {pdu.event_id: pdu}, fetch
+                room_version, state_groups, {pdu.event_id: pdu}, fetch
             )

             state = (yield self.store.get_events(state_map.values())).values()

@@ -922,6 +945,9 @@ class FederationHandler(BaseHandler):
             joinee,
             "join",
             content,
+            params={
+                "ver": KNOWN_ROOM_VERSIONS,
+            },
         )

         # This shouldn't happen, because the RoomMemberHandler has a

@@ -1150,7 +1176,7 @@ class FederationHandler(BaseHandler):
         )

         context = yield self.state_handler.compute_event_context(event)
-        yield self._persist_events([(event, context)])
+        yield self.persist_events_and_notify([(event, context)])

         defer.returnValue(event)

@@ -1181,19 +1207,20 @@ class FederationHandler(BaseHandler):
         )

         context = yield self.state_handler.compute_event_context(event)
-        yield self._persist_events([(event, context)])
+        yield self.persist_events_and_notify([(event, context)])

         defer.returnValue(event)

     @defer.inlineCallbacks
     def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
-                               content={},):
+                               content={}, params=None):
         origin, pdu = yield self.federation_client.make_membership_event(
             target_hosts,
             room_id,
             user_id,
             membership,
             content,
+            params=params,
         )

         logger.debug("Got response to make_%s: %s", membership, pdu)

@@ -1423,7 +1450,7 @@ class FederationHandler(BaseHandler):
                 event, context
             )

-        yield self._persist_events(
+        yield self.persist_events_and_notify(
             [(event, context)],
             backfilled=backfilled,
         )

@@ -1461,7 +1488,7 @@ class FederationHandler(BaseHandler):
             ], consumeErrors=True,
         ))

-        yield self._persist_events(
+        yield self.persist_events_and_notify(
             [
                 (ev_info["event"], context)
                 for ev_info, context in zip(event_infos, contexts)

@@ -1549,7 +1576,7 @@ class FederationHandler(BaseHandler):
                     raise
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR

-        yield self._persist_events(
+        yield self.persist_events_and_notify(
             [
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)

@@ -1560,7 +1587,7 @@ class FederationHandler(BaseHandler):
             event, old_state=state
         )

-        yield self._persist_events(
+        yield self.persist_events_and_notify(
             [(event, new_event_context)],
         )

@@ -1802,7 +1829,10 @@ class FederationHandler(BaseHandler):
             (d.type, d.state_key): d for d in different_events if d
         })

+        room_version = yield self.store.get_room_version(event.room_id)
+
         new_state = self.state_handler.resolve_events(
+            room_version,
             [list(local_view.values()), list(remote_view.values())],
             event
         )
@@ -2288,7 +2318,7 @@ class FederationHandler(BaseHandler):
             for revocation.
         """
         try:
-            response = yield self.hs.get_simple_http_client().get_json(
+            response = yield self.http_client.get_json(
                 url,
                 {"public_key": public_key}
             )

@@ -2301,7 +2331,7 @@ class FederationHandler(BaseHandler):
             raise AuthError(403, "Third party certificate was invalid")

     @defer.inlineCallbacks
-    def _persist_events(self, event_and_contexts, backfilled=False):
+    def persist_events_and_notify(self, event_and_contexts, backfilled=False):
         """Persists events and tells the notifier/pushers about them, if
         necessary.

@@ -2313,14 +2343,21 @@ class FederationHandler(BaseHandler):
         Returns:
             Deferred
         """
-        max_stream_id = yield self.store.persist_events(
-            event_and_contexts,
-            backfilled=backfilled,
-        )
-
-        if not backfilled:  # Never notify for backfilled events
-            for event, _ in event_and_contexts:
-                self._notify_persisted_event(event, max_stream_id)
+        if self.config.worker_app:
+            yield self._send_events_to_master(
+                store=self.store,
+                event_and_contexts=event_and_contexts,
+                backfilled=backfilled
+            )
+        else:
+            max_stream_id = yield self.store.persist_events(
+                event_and_contexts,
+                backfilled=backfilled,
+            )
+
+            if not backfilled:  # Never notify for backfilled events
+                for event, _ in event_and_contexts:
+                    self._notify_persisted_event(event, max_stream_id)

     def _notify_persisted_event(self, event, max_stream_id):
         """Checks to see if notifier/pushers should be notified about the

@@ -2353,15 +2390,30 @@ class FederationHandler(BaseHandler):
             extra_users=extra_users
         )

-        logcontext.run_in_background(
-            self.pusher_pool.on_new_notifications,
+        self.pusher_pool.on_new_notifications(
             event_stream_id, max_stream_id,
         )

     def _clean_room_for_join(self, room_id):
-        return self.store.clean_room_for_join(room_id)
+        """Called to clean up any data in DB for a given room, ready for the
+        server to join the room.
+
+        Args:
+            room_id (str)
+        """
+        if self.config.worker_app:
+            return self._clean_room_for_join_client(room_id)
+        else:
+            return self.store.clean_room_for_join(room_id)

     def user_joined_room(self, user, room_id):
         """Called when a new user has joined the room
         """
-        return user_joined_room(self.distributor, user, room_id)
+        if self.config.worker_app:
+            return self._notify_user_membership_change(
+                room_id=room_id,
+                user_id=user.to_string(),
+                change="joined",
+            )
+        else:
+            return user_joined_room(self.distributor, user, room_id)


@@ -137,15 +137,19 @@ class IdentityHandler(BaseHandler):
             defer.returnValue(data)

     @defer.inlineCallbacks
-    def unbind_threepid(self, mxid, threepid):
-        """
-        Removes a binding from an identity server
+    def try_unbind_threepid(self, mxid, threepid):
+        """Removes a binding from an identity server

         Args:
             mxid (str): Matrix user ID of binding to be removed
             threepid (dict): Dict with medium & address of binding to be removed

+        Raises:
+            SynapseError: If we failed to contact the identity server
+
         Returns:
-            Deferred[bool]: True on success, otherwise False
+            Deferred[bool]: True on success, otherwise False if the identity
+                server doesn't support unbinding
         """
         logger.debug("unbinding threepid %r from %s", threepid, mxid)
         if not self.trusted_id_servers:

@@ -175,11 +179,21 @@ class IdentityHandler(BaseHandler):
             content=content,
             destination_is=id_server,
         )
-        yield self.http_client.post_json_get_json(
-            url,
-            content,
-            headers,
-        )
+        try:
+            yield self.http_client.post_json_get_json(
+                url,
+                content,
+                headers,
+            )
+        except HttpResponseException as e:
+            if e.code in (400, 404, 501,):
+                # The remote server probably doesn't support unbinding (yet)
+                logger.warn("Received %d response while unbinding threepid", e.code)
+                defer.returnValue(False)
+            else:
+                logger.error("Failed to unbind threepid on identity server: %s", e)
+                raise SynapseError(502, "Failed to contact identity server")
+
         defer.returnValue(True)

     @defer.inlineCallbacks
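
A hedged sketch of the caller contract, from inside an inlineCallbacks method with a handler in scope: True means the binding was removed, False means the identity server predates the unbind API:

    result = yield identity_handler.try_unbind_threepid(
        "@alice:example.com",
        {"medium": "email", "address": "alice@example.com"},
    )
    if not result:
        logger.warn("identity server does not support unbinding; removed locally only")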


@@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.streams.config import PaginationConfig
 from synapse.types import StreamToken, UserID
 from synapse.util import unwrapFirstError
-from synapse.util.async import concurrently_execute
+from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.snapshot_cache import SnapshotCache
 from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 from synapse.visibility import filter_events_for_client

@@ -372,6 +372,10 @@ class InitialSyncHandler(BaseHandler):

         @defer.inlineCallbacks
         def get_presence():
+            # If presence is disabled, return an empty list
+            if not self.hs.config.use_presence:
+                defer.returnValue([])
+
             states = yield presence_handler.get_states(
                 [m.user_id for m in room_members],
                 as_event=True,


@@ -25,17 +25,24 @@ from twisted.internet import defer
 from twisted.internet.defer import succeed

 from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    ConsentNotGivenError,
+    NotFoundError,
+    SynapseError,
+)
 from synapse.api.urls import ConsentURIBuilder
 from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
-from synapse.replication.http.send_event import send_event_to_master
+from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.types import RoomAlias, UserID
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.logcontext import run_in_background
 from synapse.util.metrics import measure_func
+from synapse.visibility import filter_events_for_client

 from ._base import BaseHandler

@@ -82,28 +89,85 @@ class MessageHandler(object):
         defer.returnValue(data)

     @defer.inlineCallbacks
-    def get_state_events(self, user_id, room_id, is_guest=False):
+    def get_state_events(
+        self, user_id, room_id, types=None, filtered_types=None,
+        at_token=None, is_guest=False,
+    ):
         """Retrieve all state events for a given room. If the user is
         joined to the room then return the current state. If the user has
-        left the room return the state events from when they left.
+        left the room return the state events from when they left. If an explicit
+        'at' parameter is passed, return the state events as of that event, if
+        visible.

         Args:
             user_id(str): The user requesting state events.
             room_id(str): The room ID to get all state events from.
+            types(list[(str, str|None)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. If `state_key` is None,
+                all events are returned of the given type.
+                May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types. Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
+            at_token(StreamToken|None): the stream token of the at which we are requesting
+                the stats. If the user is not allowed to view the state as of that
+                stream token, we raise a 403 SynapseError. If None, returns the current
+                state based on the current_state_events table.
+            is_guest(bool): whether this user is a guest
         Returns:
             A list of dicts representing state events. [{}, {}, {}]
+        Raises:
+            NotFoundError (404) if the at token does not yield an event
+
+            AuthError (403) if the user doesn't have permission to view
+            members of this room.
         """
-        membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
-            room_id, user_id
-        )
-
-        if membership == Membership.JOIN:
-            room_state = yield self.state.get_current_state(room_id)
-        elif membership == Membership.LEAVE:
-            room_state = yield self.store.get_state_for_events(
-                [membership_event_id], None
-            )
-            room_state = room_state[membership_event_id]
+        if at_token:
+            # FIXME this claims to get the state at a stream position, but
+            # get_recent_events_for_room operates by topo ordering. This therefore
+            # does not reliably give you the state at the given stream position.
+            # (https://github.com/matrix-org/synapse/issues/3305)
+            last_events, _ = yield self.store.get_recent_events_for_room(
+                room_id, end_token=at_token.room_key, limit=1,
+            )
+
+            if not last_events:
+                raise NotFoundError("Can't find event for token %s" % (at_token, ))
+
+            visible_events = yield filter_events_for_client(
+                self.store, user_id, last_events,
+            )
+
+            event = last_events[0]
+            if visible_events:
+                room_state = yield self.store.get_state_for_events(
+                    [event.event_id], types, filtered_types=filtered_types,
+                )
+                room_state = room_state[event.event_id]
+            else:
+                raise AuthError(
+                    403,
+                    "User %s not allowed to view events in room %s at token %s" % (
+                        user_id, room_id, at_token,
+                    )
+                )
+        else:
+            membership, membership_event_id = (
+                yield self.auth.check_in_room_or_world_readable(
+                    room_id, user_id,
+                )
+            )
+
+            if membership == Membership.JOIN:
+                state_ids = yield self.store.get_filtered_current_state_ids(
+                    room_id, types, filtered_types=filtered_types,
+                )
+                room_state = yield self.store.get_events(state_ids.values())
+            elif membership == Membership.LEAVE:
+                room_state = yield self.store.get_state_for_events(
+                    [membership_event_id], types, filtered_types=filtered_types,
+                )
+                room_state = room_state[membership_event_id]

         now = self.clock.time_msec()
         defer.returnValue(
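
A hedged sketch of a call using the new parameters, from inside an inlineCallbacks method (the handler and token objects are assumed to be in hand; values illustrative):

    state_events = yield message_handler.get_state_events(
        user_id="@alice:example.com",
        room_id="!room:example.com",
        types=[("m.room.member", None)],   # None matches any state_key
        filtered_types=["m.room.member"],  # other event types pass through unfiltered
        at_token=stream_token,             # a StreamToken; None means current state
    )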
@ -171,7 +235,7 @@ class EventCreationHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.config = hs.config self.config = hs.config
self.http_client = hs.get_simple_http_client() self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)
# This is only used to get at ratelimit function, and maybe_kick_guest_users # This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs) self.base_handler = BaseHandler(hs)
@@ -212,10 +276,14 @@ class EventCreationHandler(object):
             where *hashes* is a map from algorithm to hash.
             If None, they will be requested from the database.

+        Raises:
+            ResourceLimitError if the server is blocked due to some resource
+                being exceeded
         Returns:
             Tuple of created event (FrozenEvent), Context
         """
+        yield self.auth.check_auth_blocking(requester.user.to_string())
+
         builder = self.event_builder_factory.new(event_dict)

         self.validator.validate_new(builder)
@@ -559,12 +627,9 @@ class EventCreationHandler(object):
         try:
             # If we're a worker we need to hit out to the master.
             if self.config.worker_app:
-                yield send_event_to_master(
-                    clock=self.hs.get_clock(),
+                yield self.send_event_to_master(
+                    event_id=event.event_id,
                     store=self.store,
-                    client=self.http_client,
-                    host=self.config.worker_replication_host,
-                    port=self.config.worker_replication_http_port,
                     requester=requester,
                     event=event,
                     context=context,
@@ -713,11 +778,8 @@ class EventCreationHandler(object):
             event, context=context
         )

-        # this intentionally does not yield: we don't care about the result
-        # and don't need to wait for it.
-        run_in_background(
-            self.pusher_pool.on_new_notifications,
-            event_stream_id, max_stream_id
+        self.pusher_pool.on_new_notifications(
+            event_stream_id, max_stream_id,
         )

         def _notify():

View File

@@ -18,11 +18,11 @@ import logging

 from twisted.internet import defer
 from twisted.python.failure import Failure

-from synapse.api.constants import Membership
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.events.utils import serialize_event
 from synapse.types import RoomStreamToken
-from synapse.util.async import ReadWriteLock
+from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.logcontext import run_in_background
 from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
@@ -251,6 +251,33 @@ class PaginationHandler(object):
             is_peeking=(member_event_id is None),
         )
state = None
if event_filter and event_filter.lazy_load_members():
# TODO: remove redundant members
types = [
(EventTypes.Member, state_key)
for state_key in set(
event.sender # FIXME: we also care about invite targets etc.
for event in events
)
]
state_ids = yield self.store.get_state_ids_for_event(
events[0].event_id, types=types,
)
if state_ids:
state = yield self.store.get_events(list(state_ids.values()))
if state:
state = yield filter_events_for_client(
self.store,
user_id,
state.values(),
is_peeking=(member_event_id is None),
)
        time_now = self.clock.time_msec()

        chunk = {
@@ -262,4 +289,10 @@ class PaginationHandler(object):
             "end": next_token.to_string(),
         }
if state:
chunk["state"] = [
serialize_event(e, time_now, as_client_event)
for e in state
]
        defer.returnValue(chunk)
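Aside (illustrative, not part of the diff): with the change above, a /messages call whose filter enables lazy-loading gains a "state" key carrying the membership events for the senders in the returned chunk. All identifiers below are placeholders.

    import json
    import requests

    resp = requests.get(
        "https://example.com/_matrix/client/r0/rooms/!room:example.com/messages",
        params={
            "access_token": "TOKEN",                 # placeholder
            "from": "t47409-4357353_219380_26003",   # pagination token (placeholder)
            "dir": "b",
            "filter": json.dumps({"lazy_load_members": True}),
        },
    )
    body = resp.json()
    sender_memberships = body.get("state", [])  # m.room.member events for senders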

View File

@@ -36,7 +36,7 @@ from synapse.api.errors import SynapseError
 from synapse.metrics import LaterGauge
 from synapse.storage.presence import UserPresenceState
 from synapse.types import UserID, get_domain_from_id
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.util.logcontext import run_in_background
 from synapse.util.logutils import log_function
@@ -95,6 +95,7 @@ class PresenceHandler(object):
         Args:
             hs (synapse.server.HomeServer):
         """
+        self.hs = hs
         self.is_mine = hs.is_mine
         self.is_mine_id = hs.is_mine_id
         self.clock = hs.get_clock()
@@ -230,6 +231,10 @@ class PresenceHandler(object):
         earlier than they should when synapse is restarted. The effect of this
         is some spurious presence changes that will self-correct.
         """
+        # If the DB pool has already terminated, don't try updating
+        if not self.hs.get_db_pool().running:
+            return
+
         logger.info(
             "Performing _on_shutdown. Persisting %d unpersisted changes",
             len(self.user_to_current_state)
@@ -390,6 +395,10 @@ class PresenceHandler(object):
         """We've seen the user do something that indicates they're interacting
         with the app.
         """
+        # If presence is disabled, no-op
+        if not self.hs.config.use_presence:
+            return
+
         user_id = user.to_string()

         bump_active_time_counter.inc()
@@ -419,6 +428,11 @@ class PresenceHandler(object):
         Useful for streams that are not associated with an actual
         client that is being used by a user.
         """
+        # Override if it should affect the user's presence, if presence is
+        # disabled.
+        if not self.hs.config.use_presence:
+            affect_presence = False
+
         if affect_presence:
             curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
             self.user_to_num_current_syncs[user_id] = curr_sync + 1
@@ -464,13 +478,16 @@ class PresenceHandler(object):
         Returns:
             set(str): A set of user_id strings.
         """
-        syncing_user_ids = {
-            user_id for user_id, count in self.user_to_num_current_syncs.items()
-            if count
-        }
-        for user_ids in self.external_process_to_current_syncs.values():
-            syncing_user_ids.update(user_ids)
-        return syncing_user_ids
+        if self.hs.config.use_presence:
+            syncing_user_ids = {
+                user_id for user_id, count in self.user_to_num_current_syncs.items()
+                if count
+            }
+            for user_ids in self.external_process_to_current_syncs.values():
+                syncing_user_ids.update(user_ids)
+            return syncing_user_ids
+        else:
+            return set()

     @defer.inlineCallbacks
     def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
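The presence checks above all hang off a single new config flag. An illustrative homeserver.yaml snippet (the value shown is an example, not a changed default):

    # Disable tracking and sharing of user presence entirely
    use_presence: false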

View File

@@ -32,12 +32,16 @@ from ._base import BaseHandler

 logger = logging.getLogger(__name__)

-class ProfileHandler(BaseHandler):
-    PROFILE_UPDATE_MS = 60 * 1000
-    PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
+class BaseProfileHandler(BaseHandler):
+    """Handles fetching and updating user profile information.
+
+    BaseProfileHandler can be instantiated directly on workers and will
+    delegate to master when necessary. The master process should use the
+    subclass MasterProfileHandler
+    """

     def __init__(self, hs):
-        super(ProfileHandler, self).__init__(hs)
+        super(BaseProfileHandler, self).__init__(hs)

         self.federation = hs.get_federation_client()
         hs.get_federation_registry().register_query_handler(
@@ -46,11 +50,6 @@ class ProfileHandler(BaseHandler):

         self.user_directory_handler = hs.get_user_directory_handler()

-        if hs.config.worker_app is None:
-            self.clock.looping_call(
-                self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
-            )
-
     @defer.inlineCallbacks
     def get_profile(self, user_id):
         target_user = UserID.from_string(user_id)
@@ -282,6 +281,20 @@ class ProfileHandler(BaseHandler):
                 room_id, str(e.message)
             )
class MasterProfileHandler(BaseProfileHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs):
super(MasterProfileHandler, self).__init__(hs)
assert hs.config.worker_app is None
self.clock.looping_call(
self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
)
    def _start_update_remote_profile_cache(self):
        return run_as_background_process(
            "Update remote profile", self._update_remote_profile_cache,

View File

@@ -17,7 +17,7 @@ import logging

 from twisted.internet import defer

-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer

 from ._base import BaseHandler

View File

@@ -18,7 +18,6 @@ from twisted.internet import defer

 from synapse.types import get_domain_from_id
 from synapse.util import logcontext
-from synapse.util.logcontext import PreserveLoggingContext

 from ._base import BaseHandler
@@ -116,16 +115,15 @@ class ReceiptsHandler(BaseHandler):

         affected_room_ids = list(set([r["room_id"] for r in receipts]))

-        with PreserveLoggingContext():
-            self.notifier.on_new_event(
-                "receipt_key", max_batch_id, rooms=affected_room_ids
-            )
-            # Note that the min here shouldn't be relied upon to be accurate.
-            self.hs.get_pusherpool().on_new_receipts(
-                min_batch_id, max_batch_id, affected_room_ids
-            )
+        self.notifier.on_new_event(
+            "receipt_key", max_batch_id, rooms=affected_room_ids
+        )
+        # Note that the min here shouldn't be relied upon to be accurate.
+        self.hs.get_pusherpool().on_new_receipts(
+            min_batch_id, max_batch_id, affected_room_ids,
+        )

         defer.returnValue(True)

     @logcontext.preserve_fn   # caller should not yield on this
     @defer.inlineCallbacks

View File

@@ -28,7 +28,7 @@ from synapse.api.errors import (
 )
 from synapse.http.client import CaptchaServerHttpClient
 from synapse.types import RoomAlias, RoomID, UserID, create_requester
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
 from synapse.util.threepids import check_3pid_allowed

 from ._base import BaseHandler
@@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler):
         Raises:
             RegistrationError if there was a problem registering.
         """
-        yield self._check_mau_limits()
+        yield self.auth.check_auth_blocking()
+
         password_hash = None
         if password:
             password_hash = yield self.auth_handler().hash(password)
@@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler):
                 400,
                 "User ID can only contain characters a-z, 0-9, or '=_-./'",
             )
-        yield self._check_mau_limits()
+        yield self.auth.check_auth_blocking()
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()
@@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler):
         """
         if localpart is None:
             raise SynapseError(400, "Request must include user id")
-        yield self._check_mau_limits()
+        yield self.auth.check_auth_blocking()
         need_register = True

         try:
@@ -533,16 +534,3 @@ class RegistrationHandler(BaseHandler):
             remote_room_hosts=remote_room_hosts,
             action="join",
         )
-
-    @defer.inlineCallbacks
-    def _check_mau_limits(self):
-        """
-        Do not accept registrations if monthly active user limits exceeded
-        and limiting is enabled
-        """
-        if self.hs.config.limit_usage_by_mau is True:
-            current_mau = yield self.store.count_monthly_users()
-            if current_mau >= self.hs.config.max_mau_value:
-                raise RegistrationError(
-                    403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
-                )
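The per-handler MAU check is replaced by the central Auth.check_auth_blocking. For reference, the limits themselves are driven by config; an illustrative homeserver.yaml snippet with example values:

    # Block new registrations (and other gated actions) once the
    # monthly-active-user cap is reached
    limit_usage_by_mau: true
    max_mau_value: 200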

View File

@@ -21,9 +21,17 @@ import math
 import string
 from collections import OrderedDict

+from six import string_types
+
 from twisted.internet import defer

-from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
+from synapse.api.constants import (
+    DEFAULT_ROOM_VERSION,
+    KNOWN_ROOM_VERSIONS,
+    EventTypes,
+    JoinRules,
+    RoomCreationPreset,
+)
 from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
 from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
 from synapse.util import stringutils
@@ -90,15 +98,34 @@ class RoomCreationHandler(BaseHandler):
         Raises:
             SynapseError if the room ID couldn't be stored, or something went
             horribly wrong.
+            ResourceLimitError if the server is blocked due to some resource
+                being exceeded
         """
         user_id = requester.user.to_string()

+        self.auth.check_auth_blocking(user_id)
+
         if not self.spam_checker.user_may_create_room(user_id):
             raise SynapseError(403, "You are not permitted to create rooms")

         if ratelimit:
             yield self.ratelimit(requester)
room_version = config.get("room_version", DEFAULT_ROOM_VERSION)
if not isinstance(room_version, string_types):
raise SynapseError(
400,
"room_version must be a string",
Codes.BAD_JSON,
)
if room_version not in KNOWN_ROOM_VERSIONS:
raise SynapseError(
400,
"Your homeserver does not support this room version",
Codes.UNSUPPORTED_ROOM_VERSION,
)
if "room_alias_name" in config: if "room_alias_name" in config:
for wchar in string.whitespace: for wchar in string.whitespace:
if wchar in config["room_alias_name"]: if wchar in config["room_alias_name"]:
@@ -184,6 +211,9 @@ class RoomCreationHandler(BaseHandler):

         creation_content = config.get("creation_content", {})

+        # override any attempt to set room versions via the creation_content
+        creation_content["room_version"] = room_version
+
         room_member_handler = self.hs.get_room_member_handler()
         yield self._send_events_for_new_room(
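Aside (illustrative, not part of the diff): a client can now pin the room version at creation time; per the validation above, a non-string value is rejected with M_BAD_JSON and an unknown version with M_UNSUPPORTED_ROOM_VERSION. The URL and token are placeholders.

    import json
    import requests

    resp = requests.post(
        "https://example.com/_matrix/client/r0/createRoom",
        params={"access_token": "TOKEN"},          # placeholder
        data=json.dumps({"room_version": "1"}),    # must be a known room version
    )
    room_id = resp.json().get("room_id")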

View File

@@ -26,7 +26,7 @@ from twisted.internet import defer

 from synapse.api.constants import EventTypes, JoinRules
 from synapse.types import ThirdPartyInstanceID
-from synapse.util.async import concurrently_execute
+from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.util.caches.response_cache import ResponseCache

View File

@@ -30,7 +30,7 @@ import synapse.types
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.types import RoomID, UserID
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room, user_left_room

 logger = logging.getLogger(__name__)
@@ -344,6 +344,7 @@ class RoomMemberHandler(object):
         latest_event_ids = (
             event_id for (event_id, _, _) in prev_events_and_hashes
         )
+
         current_state_ids = yield self.state_handler.get_current_state_ids(
             room_id, latest_event_ids=latest_event_ids,
         )

View File

@@ -20,16 +20,24 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError
 from synapse.handlers.room_member import RoomMemberHandler
 from synapse.replication.http.membership import (
-    get_or_register_3pid_guest,
-    notify_user_membership_change,
-    remote_join,
-    remote_reject_invite,
+    ReplicationRegister3PIDGuestRestServlet as Repl3PID,
+    ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
+    ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
+    ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
 )

 logger = logging.getLogger(__name__)


 class RoomMemberWorkerHandler(RoomMemberHandler):
+    def __init__(self, hs):
+        super(RoomMemberWorkerHandler, self).__init__(hs)
+
+        self._get_register_3pid_client = Repl3PID.make_client(hs)
+        self._remote_join_client = ReplRemoteJoin.make_client(hs)
+        self._remote_reject_client = ReplRejectInvite.make_client(hs)
+        self._notify_change_client = ReplJoinedLeft.make_client(hs)
+
     @defer.inlineCallbacks
     def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
         """Implements RoomMemberHandler._remote_join
@@ -37,10 +45,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         if len(remote_room_hosts) == 0:
             raise SynapseError(404, "No known servers")

-        ret = yield remote_join(
-            self.simple_http_client,
-            host=self.config.worker_replication_host,
-            port=self.config.worker_replication_http_port,
+        ret = yield self._remote_join_client(
             requester=requester,
             remote_room_hosts=remote_room_hosts,
             room_id=room_id,
@@ -55,10 +60,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
     def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
         """Implements RoomMemberHandler._remote_reject_invite
         """
-        return remote_reject_invite(
-            self.simple_http_client,
-            host=self.config.worker_replication_host,
-            port=self.config.worker_replication_http_port,
+        return self._remote_reject_client(
             requester=requester,
             remote_room_hosts=remote_room_hosts,
             room_id=room_id,
@@ -68,10 +70,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
     def _user_joined_room(self, target, room_id):
         """Implements RoomMemberHandler._user_joined_room
         """
-        return notify_user_membership_change(
-            self.simple_http_client,
-            host=self.config.worker_replication_host,
-            port=self.config.worker_replication_http_port,
+        return self._notify_change_client(
             user_id=target.to_string(),
             room_id=room_id,
             change="joined",
@@ -80,10 +79,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
     def _user_left_room(self, target, room_id):
         """Implements RoomMemberHandler._user_left_room
         """
-        return notify_user_membership_change(
-            self.simple_http_client,
-            host=self.config.worker_replication_host,
-            port=self.config.worker_replication_http_port,
+        return self._notify_change_client(
             user_id=target.to_string(),
             room_id=room_id,
             change="left",
@@ -92,10 +88,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
     def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
         """Implements RoomMemberHandler.get_or_register_3pid_guest
         """
-        return get_or_register_3pid_guest(
-            self.simple_http_client,
-            host=self.config.worker_replication_host,
-            port=self.config.worker_replication_http_port,
+        return self._get_register_3pid_client(
             requester=requester,
             medium=medium,
             address=address,
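All four call sites now share the ReplicationEndpoint.make_client pattern. A minimal sketch of the idea (not Synapse's actual implementation; the endpoint name and path are illustrative): the factory closes over the replication host and port once and returns a callable that POSTs its kwargs to the master.

    def make_client(endpoint_name, hs):
        host = hs.config.worker_replication_host
        port = hs.config.worker_replication_http_port
        client = hs.get_simple_http_client()

        def send_request(**kwargs):
            uri = "http://%s:%s/_synapse/replication/%s" % (host, port, endpoint_name)
            # serialise the kwargs as JSON and wait for the master's response
            return client.post_json_get_json(uri, kwargs)

        return send_request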

View File

@@ -25,7 +25,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.types import RoomStreamToken
-from synapse.util.async import concurrently_execute
+from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.response_cache import ResponseCache
@@ -75,6 +75,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
     "ephemeral",
     "account_data",
     "unread_notifications",
+    "summary",
 ])):
     __slots__ = []
@@ -184,6 +185,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
 class SyncHandler(object):

     def __init__(self, hs):
+        self.hs_config = hs.config
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
         self.presence_handler = hs.get_presence_handler()
@@ -191,6 +193,7 @@ class SyncHandler(object):
         self.clock = hs.get_clock()
         self.response_cache = ResponseCache(hs, "sync")
         self.state = hs.get_state_handler()
+        self.auth = hs.get_auth()

         # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
         self.lazy_loaded_members_cache = ExpiringCache(
@@ -198,19 +201,27 @@ class SyncHandler(object):
             max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
         )

+    @defer.inlineCallbacks
     def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
                                full_state=False):
         """Get the sync for a client if we have new data for it now. Otherwise
         wait for new data to arrive on the server. If the timeout expires, then
         return an empty sync result.
         Returns:
-            A Deferred SyncResult.
+            Deferred[SyncResult]
         """
-        return self.response_cache.wrap(
+        # If the user is not part of the mau group, then check that limits have
+        # not been exceeded (if not part of the group by this point, almost certain
+        # auth_blocking will occur)
+        user_id = sync_config.user.to_string()
+        yield self.auth.check_auth_blocking(user_id)
+
+        res = yield self.response_cache.wrap(
             sync_config.request_key,
             self._wait_for_sync_for_user,
             sync_config, since_token, timeout, full_state,
         )
+        defer.returnValue(res)

     @defer.inlineCallbacks
     def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
@@ -494,10 +505,142 @@ class SyncHandler(object):
             state = {}

         defer.returnValue(state)
@defer.inlineCallbacks
def compute_summary(self, room_id, sync_config, batch, state, now_token):
""" Works out a room summary block for this room, summarising the number
of joined members in the room, and providing the 'hero' members if the
room has no name so clients can consistently name rooms. Also adds
state events to 'state' if needed to describe the heroes.
Args:
room_id(str):
sync_config(synapse.handlers.sync.SyncConfig):
batch(synapse.handlers.sync.TimelineBatch): The timeline batch for
the room that will be sent to the user.
state(dict): dict of (type, state_key) -> Event as returned by
compute_state_delta
now_token(str): Token of the end of the current batch.
Returns:
A deferred dict describing the room summary
"""
# FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
last_events, _ = yield self.store.get_recent_event_ids_for_room(
room_id, end_token=now_token.room_key, limit=1,
)
if not last_events:
defer.returnValue(None)
return
last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event(
last_event.event_id, [
(EventTypes.Member, None),
(EventTypes.Name, ''),
(EventTypes.CanonicalAlias, ''),
]
)
member_ids = {
state_key: event_id
for (t, state_key), event_id in state_ids.iteritems()
if t == EventTypes.Member
}
name_id = state_ids.get((EventTypes.Name, ''))
canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ''))
summary = {}
# FIXME: it feels very heavy to load up every single membership event
# just to calculate the counts.
member_events = yield self.store.get_events(member_ids.values())
joined_user_ids = []
invited_user_ids = []
for ev in member_events.values():
if ev.content.get("membership") == Membership.JOIN:
joined_user_ids.append(ev.state_key)
elif ev.content.get("membership") == Membership.INVITE:
invited_user_ids.append(ev.state_key)
# TODO: only send these when they change.
summary["m.joined_member_count"] = len(joined_user_ids)
summary["m.invited_member_count"] = len(invited_user_ids)
if name_id or canonical_alias_id:
defer.returnValue(summary)
# FIXME: order by stream ordering, not alphabetic
me = sync_config.user.to_string()
if (joined_user_ids or invited_user_ids):
summary['m.heroes'] = sorted(
[
user_id
for user_id in (joined_user_ids + invited_user_ids)
if user_id != me
]
)[0:5]
else:
summary['m.heroes'] = sorted(
[user_id for user_id in member_ids.keys() if user_id != me]
)[0:5]
if not sync_config.filter_collection.lazy_load_members():
defer.returnValue(summary)
# ensure we send membership events for heroes if needed
cache_key = (sync_config.user.to_string(), sync_config.device_id)
cache = self.get_lazy_loaded_members_cache(cache_key)
# track which members the client should already know about via LL:
# Ones which are already in state...
existing_members = set(
user_id for (typ, user_id) in state.keys()
if typ == EventTypes.Member
)
# ...or ones which are in the timeline...
for ev in batch.events:
if ev.type == EventTypes.Member:
existing_members.add(ev.state_key)
# ...and then ensure any missing ones get included in state.
missing_hero_event_ids = [
member_ids[hero_id]
for hero_id in summary['m.heroes']
if (
cache.get(hero_id) != member_ids[hero_id] and
hero_id not in existing_members
)
]
missing_hero_state = yield self.store.get_events(missing_hero_event_ids)
missing_hero_state = missing_hero_state.values()
for s in missing_hero_state:
cache.set(s.state_key, s.event_id)
state[(EventTypes.Member, s.state_key)] = s
defer.returnValue(summary)
def get_lazy_loaded_members_cache(self, cache_key):
cache = self.lazy_loaded_members_cache.get(cache_key)
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
self.lazy_loaded_members_cache[cache_key] = cache
else:
logger.debug("found LruCache for %r", cache_key)
return cache
     @defer.inlineCallbacks
     def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
                             full_state):
-        """ Works out the differnce in state between the start of the timeline
+        """ Works out the difference in state between the start of the timeline
         and the previous sync.

         Args:
@@ -511,7 +654,7 @@ class SyncHandler(object):
             full_state(bool): Whether to force returning the full state.

         Returns:
-            A deferred new event dictionary
+            A deferred dict of (type, state_key) -> Event
         """
         # TODO(mjark) Check if the state events were received by the server
         # after the previous sync, since we need to include those state
@@ -609,13 +752,7 @@ class SyncHandler(object):
         if lazy_load_members and not include_redundant_members:
             cache_key = (sync_config.user.to_string(), sync_config.device_id)
-            cache = self.lazy_loaded_members_cache.get(cache_key)
-            if cache is None:
-                logger.debug("creating LruCache for %r", cache_key)
-                cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
-                self.lazy_loaded_members_cache[cache_key] = cache
-            else:
-                logger.debug("found LruCache for %r", cache_key)
+            cache = self.get_lazy_loaded_members_cache(cache_key)

             # if it's a new sync sequence, then assume the client has had
             # amnesia and doesn't want any recent lazy-loaded members
@@ -724,7 +861,7 @@ class SyncHandler(object):
             since_token is None and
             sync_config.filter_collection.blocks_all_presence()
         )
-        if not block_all_presence_data:
+        if self.hs_config.use_presence and not block_all_presence_data:
             yield self._generate_sync_entry_for_presence(
                 sync_result_builder, newly_joined_rooms, newly_joined_users
             )
@@ -1416,7 +1553,6 @@ class SyncHandler(object):
         if events == [] and tags is None:
             return

-        since_token = sync_result_builder.since_token
         now_token = sync_result_builder.now_token
         sync_config = sync_result_builder.sync_config
@@ -1459,6 +1595,18 @@ class SyncHandler(object):
             full_state=full_state
         )
summary = {}
if (
sync_config.filter_collection.lazy_load_members() and
(
any(ev.type == EventTypes.Member for ev in batch.events) or
since_token is None
)
):
summary = yield self.compute_summary(
room_id, sync_config, batch, state, now_token
)
        if room_builder.rtype == "joined":
            unread_notifications = {}
            room_sync = JoinedSyncResult(
@@ -1468,6 +1616,7 @@ class SyncHandler(object):
                 ephemeral=ephemeral,
                 account_data=account_data_events,
                 unread_notifications=unread_notifications,
+                summary=summary,
             )

             if room_sync or always_include:
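The resulting per-room block in the /sync response looks like this (illustrative values; per compute_summary above, "m.heroes" is only present when the room has no name or canonical alias):

    "summary": {
        "m.heroes": ["@alice:example.com", "@bob:example.com"],
        "m.joined_member_count": 2,
        "m.invited_member_count": 0
    }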

View File

@@ -119,6 +119,8 @@ class UserDirectoryHandler(object):
         """Called to update index of our local user profiles when they change
         irrespective of any rooms the user may be in.
         """
+        # FIXME(#3714): We should probably do this in the same worker as all
+        # the other changes.
         yield self.store.update_profile_in_user_dir(
             user_id, profile.display_name, profile.avatar_url, None,
         )
@@ -127,6 +129,8 @@ class UserDirectoryHandler(object):
     def handle_user_deactivated(self, user_id):
         """Called when a user ID is deactivated
         """
+        # FIXME(#3714): We should probably do this in the same worker as all
+        # the other changes.
         yield self.store.remove_from_user_dir(user_id)
         yield self.store.remove_from_user_in_public_room(user_id)

View File

@@ -42,7 +42,7 @@ from twisted.web.http_headers import Headers
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.http import cancelled_to_request_timed_out_error, redact_uri
 from synapse.http.endpoint import SpiderEndpoint
-from synapse.util.async import add_timeout_to_deferred
+from synapse.util.async_helpers import add_timeout_to_deferred
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.logcontext import make_deferred_yieldable

View File

@@ -26,7 +26,6 @@ from twisted.names.error import DNSNameError, DomainError

 logger = logging.getLogger(__name__)

 SERVER_CACHE = {}

 # our record of an individual server which can be tried to reach a destination.
@@ -103,15 +102,16 @@ def parse_and_validate_server_name(server_name):
     return host, port


-def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
+def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
                                timeout=None):
     """Construct an endpoint for the given matrix destination.

     Args:
         reactor: Twisted reactor.
         destination (bytes): The name of the server to connect to.
-        ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory
-            which generates SSL contexts to use for TLS.
+        tls_client_options_factory
+            (synapse.crypto.context_factory.ClientTLSOptionsFactory):
+            Factory which generates TLS options for client connections.
         timeout (int): connection timeout in seconds
     """
@@ -122,13 +122,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
     if timeout is not None:
         endpoint_kw_args.update(timeout=timeout)

-    if ssl_context_factory is None:
+    if tls_client_options_factory is None:
         transport_endpoint = HostnameEndpoint
         default_port = 8008
     else:
         def transport_endpoint(reactor, host, port, timeout):
             return wrapClientTLS(
-                ssl_context_factory,
+                tls_client_options_factory.get_options(host),
                 HostnameEndpoint(reactor, host, port, timeout=timeout))
         default_port = 8448

View File

@@ -43,7 +43,7 @@ from synapse.api.errors import (
 from synapse.http import cancelled_to_request_timed_out_error
 from synapse.http.endpoint import matrix_federation_endpoint
 from synapse.util import logcontext
-from synapse.util.async import add_timeout_to_deferred
+from synapse.util.async_helpers import add_timeout_to_deferred
 from synapse.util.logcontext import make_deferred_yieldable

 logger = logging.getLogger(__name__)
@@ -61,14 +61,14 @@ MAX_SHORT_RETRIES = 3

 class MatrixFederationEndpointFactory(object):
     def __init__(self, hs):
-        self.tls_server_context_factory = hs.tls_server_context_factory
+        self.tls_client_options_factory = hs.tls_client_options_factory

     def endpointForURI(self, uri):
         destination = uri.netloc

         return matrix_federation_endpoint(
             reactor, destination, timeout=10,
-            ssl_context_factory=self.tls_server_context_factory
+            tls_client_options_factory=self.tls_client_options_factory
         )
@@ -133,7 +133,7 @@ class MatrixFederationHttpClient(object):
             failures, connection failures, SSL failures.)
         """
         if (
-            self.hs.config.federation_domain_whitelist and
+            self.hs.config.federation_domain_whitelist is not None and
             destination not in self.hs.config.federation_domain_whitelist
         ):
             raise FederationDeniedError(destination)
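The distinction matters for the whitelist config: with the fix, an explicit empty list blocks every destination, while omitting the option (None) leaves federation open. Illustrative homeserver.yaml snippet:

    federation_domain_whitelist:
      - matrix.org
      - example.com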
@@ -439,7 +439,7 @@ class MatrixFederationHttpClient(object):
         defer.returnValue(json.loads(body))

     @defer.inlineCallbacks
-    def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
+    def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
                  timeout=None, ignore_backoff=False):
         """ GETs some json from the given host homeserver and path
@@ -447,7 +447,7 @@ class MatrixFederationHttpClient(object):
             destination (str): The remote server to send the HTTP request
                 to.
             path (str): The HTTP path.
-            args (dict): A dictionary used to create query strings, defaults to
+            args (dict|None): A dictionary used to create query strings, defaults to
                 None.
             timeout (int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout and that the request will
@@ -702,6 +702,9 @@ def check_content_type_is_json(headers):

 def encode_query_args(args):
+    if args is None:
+        return b""
+
     encoded_args = {}
     for k, vs in args.items():
         if isinstance(vs, string_types):
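The args={} to args=None change avoids Python's shared-mutable-default pitfall, which encode_query_args now accommodates by accepting None. A self-contained demonstration of the pitfall (generic Python, not Synapse code):

    def bad(item, acc=[]):       # the default list is created once and shared
        acc.append(item)
        return acc

    def good(item, acc=None):    # a fresh list per call
        if acc is None:
            acc = []
        acc.append(item)
        return acc

    assert bad(1) == [1]
    assert bad(2) == [1, 2]      # state leaked from the previous call
    assert good(1) == [1]
    assert good(2) == [2]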

View File

@@ -15,6 +15,7 @@
 # limitations under the License.

 import logging
+import threading

 from prometheus_client.core import Counter, Histogram
@@ -111,6 +112,9 @@ in_flight_requests_db_sched_duration = Counter(
 # The set of all in flight requests, set[RequestMetrics]
 _in_flight_requests = set()

+# Protects the _in_flight_requests set from concurrent access
+_in_flight_requests_lock = threading.Lock()
+

 def _get_in_flight_counts():
     """Returns a count of all in flight requests by (method, server_name)
@@ -120,7 +124,8 @@ def _get_in_flight_counts():
     """
     # Cast to a list to prevent it changing while the Prometheus
     # thread is collecting metrics
-    reqs = list(_in_flight_requests)
+    with _in_flight_requests_lock:
+        reqs = list(_in_flight_requests)

     for rm in reqs:
         rm.update_metrics()
@@ -154,10 +159,12 @@ class RequestMetrics(object):
         # to the "in flight" metrics.
         self._request_stats = self.start_context.get_resource_usage()

-        _in_flight_requests.add(self)
+        with _in_flight_requests_lock:
+            _in_flight_requests.add(self)

     def stop(self, time_sec, request):
-        _in_flight_requests.discard(self)
+        with _in_flight_requests_lock:
+            _in_flight_requests.discard(self)

         context = LoggingContext.current_context()
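The lock is needed because the Prometheus collector iterates the set from its own thread while request threads add and discard entries, and mutating a set during iteration raises RuntimeError. A generic sketch of the snapshot-under-lock idiom used above:

    import threading

    _shared = set()
    _lock = threading.Lock()

    def snapshot():
        # copy under the lock; the caller may then iterate the copy freely
        with _lock:
            return list(_shared)

    def track(item):
        with _lock:
            _shared.add(item)

    def untrack(item):
        with _lock:
            _shared.discard(item)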

View File

@@ -25,8 +25,9 @@ from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json

 from twisted.internet import defer
 from twisted.python import failure
-from twisted.web import resource, server
+from twisted.web import resource
 from twisted.web.server import NOT_DONE_YET
+from twisted.web.static import NoRangeStaticProducer
 from twisted.web.util import redirectTo

 import synapse.events
@@ -37,10 +38,13 @@ from synapse.api.errors import (
     SynapseError,
     UnrecognizedRequestError,
 )
-from synapse.http.request_metrics import requests_counter
 from synapse.util.caches import intern_dict
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.metrics import Measure
+from synapse.util.logcontext import preserve_fn
+
+if PY3:
+    from io import BytesIO
+else:
+    from cStringIO import StringIO as BytesIO

 logger = logging.getLogger(__name__)
@@ -60,11 +64,10 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 def wrap_json_request_handler(h):
     """Wraps a request handler method with exception handling.

-    Also adds logging as per wrap_request_handler_with_logging.
+    Also does the wrapping with request.processing as per wrap_async_request_handler.

     The handler method must have a signature of "handle_foo(self, request)",
-    where "self" must have a "clock" attribute (and "request" must be a
-    SynapseRequest).
+    where "request" must be a SynapseRequest.

     The handler must return a deferred. If the deferred succeeds we assume that
     a response has been sent. If the deferred fails with a SynapseError we use
@@ -108,24 +111,23 @@ def wrap_json_request_handler(h):
             pretty_print=_request_user_agent_is_curl(request),
         )

-    return wrap_request_handler_with_logging(wrapped_request_handler)
+    return wrap_async_request_handler(wrapped_request_handler)


 def wrap_html_request_handler(h):
     """Wraps a request handler method with exception handling.

-    Also adds logging as per wrap_request_handler_with_logging.
+    Also does the wrapping with request.processing as per wrap_async_request_handler.

     The handler method must have a signature of "handle_foo(self, request)",
-    where "self" must have a "clock" attribute (and "request" must be a
-    SynapseRequest).
+    where "request" must be a SynapseRequest.
     """
     def wrapped_request_handler(self, request):
         d = defer.maybeDeferred(h, self, request)
         d.addErrback(_return_html_error, request)
         return d

-    return wrap_request_handler_with_logging(wrapped_request_handler)
+    return wrap_async_request_handler(wrapped_request_handler)
 def _return_html_error(f, request):
@@ -170,46 +172,26 @@ def _return_html_error(f, request):
     finish_request(request)


-def wrap_request_handler_with_logging(h):
-    """Wraps a request handler to provide logging and metrics
+def wrap_async_request_handler(h):
+    """Wraps an async request handler so that it calls request.processing.
+
+    This helps ensure that work done by the request handler after the request
+    is completed is correctly recorded against the request metrics/logs.

     The handler method must have a signature of "handle_foo(self, request)",
-    where "self" must have a "clock" attribute (and "request" must be a
-    SynapseRequest).
+    where "request" must be a SynapseRequest.

-    As well as calling `request.processing` (which will log the response and
-    duration for this request), the wrapped request handler will insert the
-    request id into the logging context.
+    The handler may return a deferred, in which case the completion of the
+    request isn't logged until the deferred completes.
     """
     @defer.inlineCallbacks
-    def wrapped_request_handler(self, request):
-        """
-        Args:
-            self:
-            request (synapse.http.site.SynapseRequest):
-        """
-        request_id = request.get_request_id()
-        with LoggingContext(request_id) as request_context:
-            request_context.request = request_id
-            with Measure(self.clock, "wrapped_request_handler"):
-                # we start the request metrics timer here with an initial stab
-                # at the servlet name. For most requests that name will be
-                # JsonResource (or a subclass), and JsonResource._async_render
-                # will update it once it picks a servlet.
-                servlet_name = self.__class__.__name__
-                with request.processing(servlet_name):
-                    with PreserveLoggingContext(request_context):
-                        d = defer.maybeDeferred(h, self, request)
-
-                        # record the arrival of the request *after*
-                        # dispatching to the handler, so that the handler
-                        # can update the servlet name in the request
-                        # metrics
-                        requests_counter.labels(request.method,
-                                                request.request_metrics.name).inc()
-                        yield d
-    return wrapped_request_handler
+    def wrapped_async_request_handler(self, request):
+        with request.processing():
+            yield h(self, request)
+
+    # we need to preserve_fn here, because the synchronous render method won't
+    # yield for us (obviously)
+    return preserve_fn(wrapped_async_request_handler)
 class HttpServer(object):
@@ -272,7 +254,7 @@ class JsonResource(HttpServer, resource.Resource):
         """ This gets called by twisted every time someone sends us a request.
         """
         self._async_render(request)
-        return server.NOT_DONE_YET
+        return NOT_DONE_YET

     @wrap_json_request_handler
     @defer.inlineCallbacks
@@ -413,8 +395,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
         return

     if pretty_print:
-        json_bytes = (encode_pretty_printed_json(json_object) + "\n"
-                      ).encode("utf-8")
+        json_bytes = encode_pretty_printed_json(json_object) + b"\n"
     else:
         if canonical_json or synapse.events.USE_FROZEN_DICTS:
             # canonicaljson already encodes to bytes
@@ -450,8 +431,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
     if send_cors:
         set_cors_headers(request)

-    request.write(json_bytes)
-    finish_request(request)
+    # todo: we can almost certainly avoid this copy and encode the json straight into
+    # the bytesIO, but it would involve faffing around with string->bytes wrappers.
+    bytes_io = BytesIO(json_bytes)
+
+    producer = NoRangeStaticProducer(request, bytes_io)
+    producer.start()
     return NOT_DONE_YET

View File

@@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):

     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
+        name (bytes/unicode): the name of the query parameter.
         default (int|None): value to use if the parameter is absent, defaults
             to None.
         required (bool): whether to raise a 400 SynapseError if the
@@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):

 def parse_integer_from_args(args, name, default=None, required=False):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         try:
             return int(args[name][0])
@@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):

     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
+        name (bytes/unicode): the name of the query parameter.
         default (bool|None): value to use if the parameter is absent, defaults
             to None.
         required (bool): whether to raise a 400 SynapseError if the
@@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):

 def parse_boolean_from_args(args, name, default=None, required=False):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         try:
             return {
-                "true": True,
-                "false": False,
+                b"true": True,
+                b"false": False,
             }[args[name][0]]
         except Exception:
             message = (
@@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):

 def parse_string(request, name, default=None, required=False,
-                 allowed_values=None, param_type="string"):
-    """Parse a string parameter from the request query string.
+                 allowed_values=None, param_type="string", encoding='ascii'):
+    """
+    Parse a string parameter from the request query string.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be encoded

     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
-        default (str|None): value to use if the parameter is absent, defaults
-            to None.
+        name (bytes/unicode): the name of the query parameter.
+        default (bytes/unicode|None): value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
         required (bool): whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
-        allowed_values (list[str]): List of allowed values for the string,
-            or None if any value is allowed, defaults to None
+        allowed_values (list[bytes/unicode]): List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the name to, and decode the string
+            content with.

     Returns:
-        str|None: A string value or the default.
+        bytes/unicode|None: A string value or the default. Unicode if encoding
+        was given, bytes otherwise.

     Raises:
         SynapseError if the parameter is absent and required, or if the
@@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
         is not one of those allowed values.
     """
     return parse_string_from_args(
-        request.args, name, default, required, allowed_values, param_type,
+        request.args, name, default, required, allowed_values, param_type, encoding
     )


 def parse_string_from_args(args, name, default=None, required=False,
-                           allowed_values=None, param_type="string"):
+                           allowed_values=None, param_type="string", encoding='ascii'):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         value = args[name][0]
+
+        if encoding:
+            value = value.decode(encoding)
+
         if allowed_values is not None and value not in allowed_values:
             message = "Query parameter %r must be one of [%s]" % (
                 name, ", ".join(repr(v) for v in allowed_values)
@@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
             message = "Missing %s query parameter %r" % (param_type, name)
             raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
     else:
+
+        if encoding and isinstance(default, bytes):
+            return default.decode(encoding)
+
         return default
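Illustrative usage of the reworked helper (twisted supplies query args as a bytes-to-list-of-bytes mapping; with the default encoding, a unicode name goes in and a decoded unicode value comes out):

    args = {b"dir": [b"b"]}
    direction = parse_string_from_args(
        args, "dir", default="f", allowed_values=["f", "b"],
    )
    assert direction == "b"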

View File

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
 import logging
 import time
@@ -19,8 +18,8 @@ import time

 from twisted.web.server import Request, Site

 from synapse.http import redact_uri
-from synapse.http.request_metrics import RequestMetrics
-from synapse.util.logcontext import ContextResourceUsage, LoggingContext
+from synapse.http.request_metrics import RequestMetrics, requests_counter
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext

 logger = logging.getLogger(__name__)
@@ -34,25 +33,43 @@ class SynapseRequest(Request):
     It extends twisted's twisted.web.server.Request, and adds:
      * Unique request ID
+     * A log context associated with the request
      * Redaction of access_token query-params in __repr__
      * Logging at start and end
      * Metrics to record CPU, wallclock and DB time by endpoint.

-    It provides a method `processing` which should be called by the Resource
-    which is handling the request, and returns a context manager.
+    It also provides a method `processing`, which returns a context manager. If this
+    method is called, the request won't be logged until the context manager is closed;
+    this is useful for asynchronous request handlers which may go on processing the
+    request even after the client has disconnected.
+
+    Attributes:
+        logcontext(LoggingContext) : the log context for this request
     """
     def __init__(self, site, channel, *args, **kw):
         Request.__init__(self, channel, *args, **kw)
         self.site = site
-        self._channel = channel
+        self._channel = channel  # this is used by the tests
         self.authenticated_entity = None
         self.start_time = 0

+        # we can't yet create the logcontext, as we don't know the method.
+        self.logcontext = None
+
         global _next_request_seq
         self.request_seq = _next_request_seq
         _next_request_seq += 1

+        # whether an asynchronous request handler has called processing()
+        self._is_processing = False
+
+        # the time when the asynchronous request handler completed its processing
+        self._processing_finished_time = None
+
+        # what time we finished sending the response to the client (or the connection
+        # dropped)
+        self.finish_time = None
+
     def __repr__(self):
         # We overwrite this so that we don't log ``access_token``
         return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
@ -74,11 +91,116 @@ class SynapseRequest(Request):
return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
def render(self, resrc): def render(self, resrc):
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
# create a LogContext for this request
request_id = self.get_request_id()
logcontext = self.logcontext = LoggingContext(request_id)
logcontext.request = request_id
# override the Server header which is set by twisted # override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string) self.setHeader("Server", self.site.server_version_string)
return Request.render(self, resrc)
with PreserveLoggingContext(self.logcontext):
# we start the request metrics timer here with an initial stab
# at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet.
servlet_name = resrc.__class__.__name__
self._started_processing(servlet_name)
Request.render(self, resrc)
# record the arrival of the request *after*
# dispatching to the handler, so that the handler
# can update the servlet name in the request
# metrics
requests_counter.labels(self.method,
self.request_metrics.name).inc()
@contextlib.contextmanager
def processing(self):
"""Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is:
@defer.inlineCallbacks
def handle_request(request):
with request.processing("FooServlet"):
yield really_handle_the_request()
Once the context manager is closed, the completion of the request will be logged,
and the various metrics will be updated.
"""
if self._is_processing:
raise RuntimeError("Request is already processing")
self._is_processing = True
try:
yield
except Exception:
# this should already have been caught, and sent back to the client as a 500.
logger.exception("Asynchronous messge handler raised an uncaught exception")
finally:
# the request handler has finished its work and either sent the whole response
# back, or handed over responsibility to a Producer.
self._processing_finished_time = time.time()
self._is_processing = False
# if we've already sent the response, log it now; otherwise, we wait for the
# response to be sent.
if self.finish_time is not None:
self._finished_processing()
def finish(self):
"""Called when all response data has been written to this Request.
Overrides twisted.web.server.Request.finish to record the finish time and do
logging.
"""
self.finish_time = time.time()
Request.finish(self)
if not self._is_processing:
with PreserveLoggingContext(self.logcontext):
self._finished_processing()
def connectionLost(self, reason):
"""Called when the client connection is closed before the response is written.
Overrides twisted.web.server.Request.connectionLost to record the finish time and
do logging.
"""
self.finish_time = time.time()
Request.connectionLost(self, reason)
# we only get here if the connection to the client drops before we send
# the response.
#
# It's useful to log it here so that we can get an idea of when
# the client disconnects.
with PreserveLoggingContext(self.logcontext):
logger.warn(
"Error processing request %r: %s %s", self, reason.type, reason.value,
)
if not self._is_processing:
self._finished_processing()
def _started_processing(self, servlet_name): def _started_processing(self, servlet_name):
"""Record the fact that we are processing this request.
This will log the request's arrival. Once the request completes,
be sure to call finished_processing.
Args:
servlet_name (str): the name of the servlet which will be
processing this request. This is used in the metrics.
It is possible to update this afterwards by updating
self.request_metrics.name.
"""
self.start_time = time.time() self.start_time = time.time()
self.request_metrics = RequestMetrics() self.request_metrics = RequestMetrics()
self.request_metrics.start( self.request_metrics.start(
@ -94,18 +216,32 @@ class SynapseRequest(Request):
) )
def _finished_processing(self): def _finished_processing(self):
try: """Log the completion of this request and update the metrics
context = LoggingContext.current_context() """
usage = context.get_resource_usage()
except Exception:
usage = ContextResourceUsage()
end_time = time.time() if self.logcontext is None:
# this can happen if the connection closed before we read the
# headers (so render was never called). In that case we'll already
# have logged a warning, so just bail out.
return
usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None:
# we completed the request without anything calling processing()
self._processing_finished_time = time.time()
# the time between receiving the request and the request handler finishing
processing_time = self._processing_finished_time - self.start_time
# the time between the request handler finishing and the response being sent
# to the client (nb may be negative)
response_send_time = self.finish_time - self._processing_finished_time
# need to decode as it could be raw utf-8 bytes # need to decode as it could be raw utf-8 bytes
# from a IDN servname in an auth header # from a IDN servname in an auth header
authenticated_entity = self.authenticated_entity authenticated_entity = self.authenticated_entity
if authenticated_entity is not None: if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
authenticated_entity = authenticated_entity.decode("utf-8", "replace") authenticated_entity = authenticated_entity.decode("utf-8", "replace")
# ...or could be raw utf-8 bytes in the User-Agent header. # ...or could be raw utf-8 bytes in the User-Agent header.
@ -116,22 +252,31 @@ class SynapseRequest(Request):
user_agent = self.get_user_agent() user_agent = self.get_user_agent()
if user_agent is not None: if user_agent is not None:
user_agent = user_agent.decode("utf-8", "replace") user_agent = user_agent.decode("utf-8", "replace")
else:
user_agent = "-"
code = str(self.code)
if not self.finished:
# we didn't send the full response before we gave up (presumably because
# the connection dropped)
code += "!"
self.site.access_logger.info( self.site.access_logger.info(
"%s - %s - {%s}" "%s - %s - {%s}"
" Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.site.site_tag,
authenticated_entity, authenticated_entity,
end_time - self.start_time, processing_time,
response_send_time,
usage.ru_utime, usage.ru_utime,
usage.ru_stime, usage.ru_stime,
usage.db_sched_duration_sec, usage.db_sched_duration_sec,
usage.db_txn_duration_sec, usage.db_txn_duration_sec,
int(usage.db_txn_count), int(usage.db_txn_count),
self.sentLength, self.sentLength,
self.code, code,
self.method, self.method,
self.get_redacted_uri(), self.get_redacted_uri(),
self.clientproto, self.clientproto,
@ -140,38 +285,10 @@ class SynapseRequest(Request):
) )
try: try:
self.request_metrics.stop(end_time, self) self.request_metrics.stop(self.finish_time, self)
except Exception as e: except Exception as e:
logger.warn("Failed to stop metrics: %r", e) logger.warn("Failed to stop metrics: %r", e)
@contextlib.contextmanager
def processing(self, servlet_name):
"""Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is:
@defer.inlineCallbacks
def handle_request(request):
with request.processing("FooServlet"):
yield really_handle_the_request()
This will log the request's arrival. Once the context manager is
closed, the completion of the request will be logged, and the various
metrics will be updated.
Args:
servlet_name (str): the name of the servlet which will be
processing this request. This is used in the metrics.
It is possible to update this afterwards by updating
self.request_metrics.servlet_name.
"""
# TODO: we should probably just move this into render() and finish(),
# to save having to call a separate method.
self._started_processing(servlet_name)
yield
self._finished_processing()
class XForwardedForRequest(SynapseRequest): class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
@ -217,7 +334,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False) proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied) self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name) self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string self.server_version_string = server_version_string.encode('ascii')
def log(self, request): def log(self, request):
pass pass
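The access log above now reports two durations instead of one: the time the request handler took, and the time spent sending the response afterwards. A minimal, self-contained sketch of the arithmetic (the sleeps stand in for real work):

    import time

    start_time = time.time()                # request received (render())
    time.sleep(0.01)                        # stand-in for the request handler
    processing_finished_time = time.time()  # processing() context exits
    time.sleep(0.005)                       # stand-in for streaming the body
    finish_time = time.time()               # finish()/connectionLost fires

    processing_time = processing_finished_time - start_time
    # nb: may be negative if the body was sent before the handler finished
    response_send_time = finish_time - processing_finished_time
    print("Processed request: %.3fsec/%.3fsec" % (
        processing_time, response_send_time,
    ))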

View File

@ -174,6 +174,19 @@ sent_transactions_counter = Counter("synapse_federation_client_sent_transactions
events_processed_counter = Counter("synapse_federation_client_events_processed", "") events_processed_counter = Counter("synapse_federation_client_events_processed", "")
event_processing_loop_counter = Counter(
"synapse_event_processing_loop_count",
"Event processing loop iterations",
["name"],
)
event_processing_loop_room_count = Counter(
"synapse_event_processing_loop_room_count",
"Rooms seen per event processing loop iteration",
["name"],
)
# Used to track where various components have processed in the event stream, # Used to track where various components have processed in the event stream,
# e.g. federation sending, appservice sending, etc. # e.g. federation sending, appservice sending, etc.
event_processing_positions = Gauge("synapse_event_processing_positions", "", ["name"]) event_processing_positions = Gauge("synapse_event_processing_positions", "", ["name"])
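These are ordinary prometheus_client counters with a single `name` label; an illustrative use, with a hypothetical label value:

    # one processing loop iteration which touched three rooms
    event_processing_loop_counter.labels("federation_sender").inc()
    event_processing_loop_room_count.labels("federation_sender").inc(3)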

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import threading
import six import six
from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
@ -78,6 +80,9 @@ _background_process_counts = dict() # type: dict[str, int]
# of process descriptions that no longer have any active processes. # of process descriptions that no longer have any active processes.
_background_processes = dict() # type: dict[str, set[_BackgroundProcess]] _background_processes = dict() # type: dict[str, set[_BackgroundProcess]]
# A lock that covers the above dicts
_bg_metrics_lock = threading.Lock()
class _Collector(object): class _Collector(object):
"""A custom metrics collector for the background process metrics. """A custom metrics collector for the background process metrics.
@ -92,7 +97,11 @@ class _Collector(object):
labels=["name"], labels=["name"],
) )
for desc, processes in six.iteritems(_background_processes): # We copy the dict so that it doesn't change from underneath us
with _bg_metrics_lock:
_background_processes_copy = dict(_background_processes)
for desc, processes in six.iteritems(_background_processes_copy):
background_process_in_flight_count.add_metric( background_process_in_flight_count.add_metric(
(desc,), len(processes), (desc,), len(processes),
) )
@ -167,19 +176,26 @@ def run_as_background_process(desc, func, *args, **kwargs):
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def run(): def run():
count = _background_process_counts.get(desc, 0) with _bg_metrics_lock:
_background_process_counts[desc] = count + 1 count = _background_process_counts.get(desc, 0)
_background_process_counts[desc] = count + 1
_background_process_start_count.labels(desc).inc() _background_process_start_count.labels(desc).inc()
with LoggingContext(desc) as context: with LoggingContext(desc) as context:
context.request = "%s-%i" % (desc, count) context.request = "%s-%i" % (desc, count)
proc = _BackgroundProcess(desc, context) proc = _BackgroundProcess(desc, context)
_background_processes.setdefault(desc, set()).add(proc)
with _bg_metrics_lock:
_background_processes.setdefault(desc, set()).add(proc)
try: try:
yield func(*args, **kwargs) yield func(*args, **kwargs)
finally: finally:
proc.update_metrics() proc.update_metrics()
_background_processes[desc].remove(proc)
with _bg_metrics_lock:
_background_processes[desc].remove(proc)
with PreserveLoggingContext(): with PreserveLoggingContext():
return run() return run()
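The locking discipline above, copy under the lock and iterate outside it, reduced to a standalone sketch (all names invented):

    import threading

    _registry = {}                      # shared mutable state
    _registry_lock = threading.Lock()   # guards _registry

    def snapshot_registry():
        # Take a shallow copy while holding the lock, so iteration can
        # proceed without it and without racing concurrent mutations.
        with _registry_lock:
            registry_copy = dict(_registry)
        for desc, processes in registry_copy.items():
            print(desc, len(processes))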

View File

@ -25,7 +25,7 @@ from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.util.async import ( from synapse.util.async_helpers import (
DeferredTimeoutError, DeferredTimeoutError,
ObservableDeferred, ObservableDeferred,
add_timeout_to_deferred, add_timeout_to_deferred,

View File

@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.event_auth import get_user_power_level from synapse.event_auth import get_user_power_level
from synapse.state import POWER_KEY from synapse.state import POWER_KEY
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached

View File

@ -35,7 +35,7 @@ from synapse.push.presentable_names import (
name_from_member_event, name_from_member_event,
) )
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -18,6 +18,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push.pusher import PusherFactory from synapse.push.pusher import PusherFactory
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@ -122,8 +123,14 @@ class PusherPool:
p['app_id'], p['pushkey'], p['user_name'], p['app_id'], p['pushkey'], p['user_name'],
) )
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id): def on_new_notifications(self, min_stream_id, max_stream_id):
run_as_background_process(
"on_new_notifications",
self._on_new_notifications, min_stream_id, max_stream_id,
)
@defer.inlineCallbacks
def _on_new_notifications(self, min_stream_id, max_stream_id):
try: try:
users_affected = yield self.store.get_push_action_users_in_range( users_affected = yield self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id min_stream_id, max_stream_id
@ -147,8 +154,14 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
run_as_background_process(
"on_new_receipts",
self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids,
)
@defer.inlineCallbacks
def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
try: try:
# Need to subtract 1 from the minimum because the lower bound here # Need to subtract 1 from the minimum because the lower bound here
# is not inclusive # is not inclusive
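The same pattern is applied to both notification paths above: the public entry point fires the real work off as a tracked background process instead of returning a Deferred the caller must wait on. In isolation (names invented):

    def on_new_thing(self, arg):
        run_as_background_process("on_new_thing", self._on_new_thing, arg)

    @defer.inlineCallbacks
    def _on_new_thing(self, arg):
        try:
            yield self._do_the_work(arg)   # hypothetical worker method
        except Exception:
            logger.exception("Exception handling new thing")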

View File

@ -39,7 +39,7 @@ REQUIREMENTS = {
"signedjson>=1.0.0": ["signedjson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"], "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=16.0.0": ["twisted>=16.0.0"], "Twisted>=17.1.0": ["twisted>=17.1.0"],
# We use crypto.get_elliptic_curve which is only supported in >=0.15 # We use crypto.get_elliptic_curve which is only supported in >=0.15
"pyopenssl>=0.15": ["OpenSSL>=0.15"], "pyopenssl>=0.15": ["OpenSSL>=0.15"],

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.replication.http import membership, send_event from synapse.replication.http import federation, membership, send_event
REPLICATION_PREFIX = "/_synapse/replication" REPLICATION_PREFIX = "/_synapse/replication"
@ -27,3 +27,4 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs): def register_servlets(self, hs):
send_event.register_servlets(hs, self) send_event.register_servlets(hs, self)
membership.register_servlets(hs, self) membership.register_servlets(hs, self)
federation.register_servlets(hs, self)

View File

@ -0,0 +1,215 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import re
from six.moves import urllib
from twisted.internet import defer
from synapse.api.errors import CodeMessageException, HttpResponseException
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
class ReplicationEndpoint(object):
"""Helper base class for defining new replication HTTP endpoints.
This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
(with a `/:txn_id` suffix for cached requests), where NAME is a name,
PATH_ARGS are a tuple of parameters to be encoded in the URL.
For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`,
with `CACHE` set to true then this generates an endpoint:
/_synapse/replication/send_event/:event_id/:txn_id
For POST/PUT requests the payload is serialized to json and sent as the
body, while for GET requests the payload is added as query parameters. See
`_serialize_payload` for details.
Incoming requests are handled by overriding `_handle_request`. Servers
must call `register` to register the path with the HTTP server.
Requests can be sent by calling the client returned by `make_client`.
Attributes:
NAME (str): A name for the endpoint, added to the path as well as used
in logging and metrics.
PATH_ARGS (tuple[str]): A list of parameters to be added to the path.
Adding parameters to the path (rather than payload) can make it
easier to follow along in the log files.
METHOD (str): The method of the HTTP request, defaults to POST. Can be
one of POST, PUT or GET. If GET then the payload is sent as query
parameters rather than a JSON body.
CACHE (bool): Whether the server should cache the result of the request.
If true then transparently adds a txn_id to all requests, and
`_handle_request` must return a Deferred.
RETRY_ON_TIMEOUT (bool): Whether or not to retry the request when a 504
is received.
"""
__metaclass__ = abc.ABCMeta
NAME = abc.abstractproperty()
PATH_ARGS = abc.abstractproperty()
METHOD = "POST"
CACHE = True
RETRY_ON_TIMEOUT = True
def __init__(self, hs):
if self.CACHE:
self.response_cache = ResponseCache(
hs, "repl." + self.NAME,
timeout_ms=30 * 60 * 1000,
)
assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod
def _serialize_payload(**kwargs):
"""Static method that is called when creating a request.
Concrete implementations should have explicit parameters (rather than
kwargs) so that an appropriate exception is raised if the client is
called with unexpected parameters. All PATH_ARGS must appear in the
argument list.
Returns:
Deferred[dict]|dict: If POST/PUT request then dictionary must be
JSON serialisable, otherwise must be appropriate for adding as
query args.
"""
return {}
@abc.abstractmethod
def _handle_request(self, request, **kwargs):
"""Handle incoming request.
This is called with the request object and PATH_ARGS.
Returns:
Deferred[dict]: A JSON serialisable dict to be used as response
body of request.
"""
pass
@classmethod
def make_client(cls, hs):
"""Create a client that makes requests.
Returns a callable that accepts the same parameters as `_serialize_payload`.
"""
clock = hs.get_clock()
host = hs.config.worker_replication_host
port = hs.config.worker_replication_http_port
client = hs.get_simple_http_client()
@defer.inlineCallbacks
def send_request(**kwargs):
data = yield cls._serialize_payload(**kwargs)
url_args = [urllib.parse.quote(kwargs[name]) for name in cls.PATH_ARGS]
if cls.CACHE:
txn_id = random_string(10)
url_args.append(txn_id)
if cls.METHOD == "POST":
request_func = client.post_json_get_json
elif cls.METHOD == "PUT":
request_func = client.put_json
elif cls.METHOD == "GET":
request_func = client.get_json
else:
# We have already asserted in the constructor that a
# compatible method was picked, but let's be paranoid.
raise Exception(
"Unknown METHOD on %s replication endpoint" % (cls.NAME,)
)
uri = "http://%s:%s/_synapse/replication/%s/%s" % (
host, port, cls.NAME, "/".join(url_args)
)
try:
# We keep retrying the same request for timeouts. This is so that we
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
try:
result = yield request_func(uri, data)
break
except CodeMessageException as e:
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
raise
logger.warn("%s request timed out", cls.NAME)
# If we timed out we probably don't need to worry about backing
# off too much, but let's just wait a little anyway.
yield clock.sleep(1)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise e.to_synapse_error()
defer.returnValue(result)
return send_request
def register(self, http_server):
"""Called by the server to register this as a handler to the
appropriate path.
"""
url_args = list(self.PATH_ARGS)
handler = self._handle_request
method = self.METHOD
if self.CACHE:
handler = self._cached_handler
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (
self.NAME,
args
))
http_server.register_paths(method, [pattern], handler)
def _cached_handler(self, request, txn_id, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
"""
# We just use the txn_id here, but we probably also want to use the
# other PATH_ARGS as well.
assert self.CACHE
return self.response_cache.wrap(
txn_id,
self._handle_request,
request, **kwargs
)
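A hypothetical subclass, to show how the pieces of `ReplicationEndpoint` fit together; the endpoint name, payload, and handler body are invented for illustration:

    from twisted.internet import defer

    from synapse.http.servlet import parse_json_object_from_request
    from synapse.replication.http._base import ReplicationEndpoint

    class ExampleRestServlet(ReplicationEndpoint):
        NAME = "example"            # /_synapse/replication/example/:thing_id/:txn_id
        PATH_ARGS = ("thing_id",)   # encoded into the URL path
        METHOD = "POST"
        CACHE = True                # appends /:txn_id and caches the response

        @staticmethod
        def _serialize_payload(thing_id, data):
            # thing_id travels in the URL; only the rest goes in the body
            return {"data": data}

        def _handle_request(self, request, thing_id):
            content = parse_json_object_from_request(request)
            # ... act on content["data"] for thing_id ...
            return defer.succeed((200, {}))   # CACHE requires a Deferred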

View File

@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""Handles events newly received from federation, including persisting and
notifying.
The API looks like:
POST /_synapse/replication/fed_send_events/:txn_id
{
"events": [{
"event": { .. serialized event .. },
"internal_metadata": { .. serialized internal_metadata .. },
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
}],
"backfilled": false
"""
NAME = "fed_send_events"
PATH_ARGS = ()
def __init__(self, hs):
super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_handler = hs.get_handlers().federation_handler
@staticmethod
@defer.inlineCallbacks
def _serialize_payload(store, event_and_contexts, backfilled):
"""
Args:
store
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether or not the events are the result of
backfilling
"""
event_payloads = []
for event, context in event_and_contexts:
serialized_context = yield context.serialize(event, store)
event_payloads.append({
"event": event.get_pdu_json(),
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
"context": serialized_context,
})
payload = {
"events": event_payloads,
"backfilled": backfilled,
}
defer.returnValue(payload)
@defer.inlineCallbacks
def _handle_request(self, request):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
backfilled = content["backfilled"]
event_payloads = content["events"]
event_and_contexts = []
for event_payload in event_payloads:
event_dict = event_payload["event"]
internal_metadata = event_payload["internal_metadata"]
rejected_reason = event_payload["rejected_reason"]
event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
context = yield EventContext.deserialize(
self.store, event_payload["context"],
)
event_and_contexts.append((event, context))
logger.info(
"Got %d events from federation",
len(event_and_contexts),
)
yield self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled,
)
defer.returnValue((200, {}))
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
"""Handles EDUs newly received from federation, including persisting and
notifying.
Request format:
POST /_synapse/replication/fed_send_edu/:edu_type/:txn_id
{
"origin": ...,
"content: { ... }
}
"""
NAME = "fed_send_edu"
PATH_ARGS = ("edu_type",)
def __init__(self, hs):
super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.registry = hs.get_federation_registry()
@staticmethod
def _serialize_payload(edu_type, origin, content):
return {
"origin": origin,
"content": content,
}
@defer.inlineCallbacks
def _handle_request(self, request, edu_type):
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)
origin = content["origin"]
edu_content = content["content"]
logger.info(
"Got %r edu from %s",
edu_type, origin,
)
result = yield self.registry.on_edu(edu_type, origin, edu_content)
defer.returnValue((200, result))
class ReplicationGetQueryRestServlet(ReplicationEndpoint):
"""Handle responding to queries from federation.
Request format:
POST /_synapse/replication/fed_query/:query_type
{
"args": { ... }
}
"""
NAME = "fed_query"
PATH_ARGS = ("query_type",)
# This is a query, so let's not bother caching
CACHE = False
def __init__(self, hs):
super(ReplicationGetQueryRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.registry = hs.get_federation_registry()
@staticmethod
def _serialize_payload(query_type, args):
"""
Args:
query_type (str)
args (dict): The arguments received for the given query type
"""
return {
"args": args,
}
@defer.inlineCallbacks
def _handle_request(self, request, query_type):
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)
args = content["args"]
logger.info(
"Got %r query",
query_type,
)
result = yield self.registry.on_query(query_type, args)
defer.returnValue((200, result))
class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
Request format:
POST /_synapse/replication/fed_cleanup_room/:room_id/:txn_id
{}
"""
NAME = "fed_cleanup_room"
PATH_ARGS = ("room_id",)
def __init__(self, hs):
super(ReplicationCleanRoomRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
@staticmethod
def _serialize_payload(room_id):
"""
Args:
room_id (str)
"""
return {}
@defer.inlineCallbacks
def _handle_request(self, request, room_id):
yield self.store.clean_room_for_join(room_id)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server)
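On the worker side each of these is driven through `make_client`; a hypothetical usage sketch, assuming `hs` is the homeserver object:

    send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)

    @defer.inlineCallbacks
    def forward_edu(edu_type, origin, content):
        # keyword arguments mirror _serialize_payload's signature
        yield send_edu(edu_type=edu_type, origin=origin, content=content)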

View File

@ -14,182 +14,63 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import HttpResponseException from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID from synapse.types import Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room from synapse.util.distributor import user_joined_room, user_left_room
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@defer.inlineCallbacks class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
def remote_join(client, host, port, requester, remote_room_hosts, """Does a remote join for the given user to the given room
room_id, user_id, content):
"""Ask the master to do a remote join for the given user to the given room
Args: Request format:
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
requester (Requester)
remote_room_hosts (list[str]): Servers to try and join via
room_id (str)
user_id (str)
content (dict): The event content to use for the join event
Returns: POST /_synapse/replication/remote_join/:room_id/:user_id
Deferred
"""
uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
payload = { {
"requester": requester.serialize(), "requester": ...,
"remote_room_hosts": remote_room_hosts, "remote_room_hosts": [...],
"room_id": room_id, "content": { ... }
"user_id": user_id, }
"content": content,
}
try:
result = yield client.post_json_get_json(uri, payload)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise e.to_synapse_error()
defer.returnValue(result)
@defer.inlineCallbacks
def remote_reject_invite(client, host, port, requester, remote_room_hosts,
room_id, user_id):
"""Ask master to reject the invite for the user and room.
Args:
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
requester (Requester)
remote_room_hosts (list[str]): Servers to try and reject via
room_id (str)
user_id (str)
Returns:
Deferred
"""
uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port)
payload = {
"requester": requester.serialize(),
"remote_room_hosts": remote_room_hosts,
"room_id": room_id,
"user_id": user_id,
}
try:
result = yield client.post_json_get_json(uri, payload)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise e.to_synapse_error()
defer.returnValue(result)
@defer.inlineCallbacks
def get_or_register_3pid_guest(client, host, port, requester,
medium, address, inviter_user_id):
"""Ask the master to get/create a guest account for given 3PID.
Args:
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
requester (Requester)
medium (str)
address (str)
inviter_user_id (str): The user ID who is trying to invite the
3PID
Returns:
Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
3PID guest account.
""" """
uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port) NAME = "remote_join"
PATH_ARGS = ("room_id", "user_id",)
payload = {
"requester": requester.serialize(),
"medium": medium,
"address": address,
"inviter_user_id": inviter_user_id,
}
try:
result = yield client.post_json_get_json(uri, payload)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise e.to_synapse_error()
defer.returnValue(result)
@defer.inlineCallbacks
def notify_user_membership_change(client, host, port, user_id, room_id, change):
"""Notify master that a user has joined or left the room
Args:
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication.
user_id (str)
room_id (str)
change (str): Either "join" or "left"
Returns:
Deferred
"""
assert change in ("joined", "left")
uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change)
payload = {
"user_id": user_id,
"room_id": room_id,
}
try:
result = yield client.post_json_get_json(uri, payload)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise e.to_synapse_error()
defer.returnValue(result)
class ReplicationRemoteJoinRestServlet(RestServlet):
PATTERNS = [re.compile("^/_synapse/replication/remote_join$")]
def __init__(self, hs): def __init__(self, hs):
super(ReplicationRemoteJoinRestServlet, self).__init__() super(ReplicationRemoteJoinRestServlet, self).__init__(hs)
self.federation_handler = hs.get_handlers().federation_handler self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod
def _serialize_payload(requester, room_id, user_id, remote_room_hosts,
content):
"""
Args:
requester(Requester)
room_id (str)
user_id (str)
remote_room_hosts (list[str]): Servers to try and join via
content(dict): The event content to use for the join event
"""
return {
"requester": requester.serialize(),
"remote_room_hosts": remote_room_hosts,
"content": content,
}
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"] remote_room_hosts = content["remote_room_hosts"]
room_id = content["room_id"]
user_id = content["user_id"]
event_content = content["content"] event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"]) requester = Requester.deserialize(self.store, content["requester"])
@ -212,23 +93,48 @@ class ReplicationRemoteJoinRestServlet(RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ReplicationRemoteRejectInviteRestServlet(RestServlet): class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")] """Rejects the invite for the user and room.
Request format:
POST /_synapse/replication/remote_reject_invite/:room_id/:user_id
{
"requester": ...,
"remote_room_hosts": [...],
}
"""
NAME = "remote_reject_invite"
PATH_ARGS = ("room_id", "user_id",)
def __init__(self, hs): def __init__(self, hs):
super(ReplicationRemoteRejectInviteRestServlet, self).__init__() super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
self.federation_handler = hs.get_handlers().federation_handler self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod
def _serialize_payload(requester, room_id, user_id, remote_room_hosts):
"""
Args:
requester(Requester)
room_id (str)
user_id (str)
remote_room_hosts (list[str]): Servers to try and reject via
"""
return {
"requester": requester.serialize(),
"remote_room_hosts": remote_room_hosts,
}
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"] remote_room_hosts = content["remote_room_hosts"]
room_id = content["room_id"]
user_id = content["user_id"]
requester = Requester.deserialize(self.store, content["requester"]) requester = Requester.deserialize(self.store, content["requester"])
@ -264,18 +170,50 @@ class ReplicationRemoteRejectInviteRestServlet(RestServlet):
defer.returnValue((200, ret)) defer.returnValue((200, ret))
class ReplicationRegister3PIDGuestRestServlet(RestServlet): class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint):
PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")] """Gets/creates a guest account for given 3PID.
Request format:
POST /_synapse/replication/get_or_register_3pid_guest/
{
"requester": ...,
"medium": ...,
"address": ...,
"inviter_user_id": ...
}
"""
NAME = "get_or_register_3pid_guest"
PATH_ARGS = ()
def __init__(self, hs): def __init__(self, hs):
super(ReplicationRegister3PIDGuestRestServlet, self).__init__() super(ReplicationRegister3PIDGuestRestServlet, self).__init__(hs)
self.registeration_handler = hs.get_handlers().registration_handler self.registeration_handler = hs.get_handlers().registration_handler
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod
def _serialize_payload(requester, medium, address, inviter_user_id):
"""
Args:
requester(Requester)
medium (str)
address (str)
inviter_user_id (str): The user ID who is trying to invite the
3PID
"""
return {
"requester": requester.serialize(),
"medium": medium,
"address": address,
"inviter_user_id": inviter_user_id,
}
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def _handle_request(self, request):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
medium = content["medium"] medium = content["medium"]
@ -296,23 +234,41 @@ class ReplicationRegister3PIDGuestRestServlet(RestServlet):
defer.returnValue((200, ret)) defer.returnValue((200, ret))
class ReplicationUserJoinedLeftRoomRestServlet(RestServlet): class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")] """Notifies that a user has joined or left the room
Request format:
POST /_synapse/replication/membership_change/:room_id/:user_id/:change
{}
"""
NAME = "membership_change"
PATH_ARGS = ("room_id", "user_id", "change")
CACHE = False # No point caching as should return instantly.
def __init__(self, hs): def __init__(self, hs):
super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__() super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__(hs)
self.registeration_handler = hs.get_handlers().registration_handler self.registeration_handler = hs.get_handlers().registration_handler
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
def on_POST(self, request, change): @staticmethod
content = parse_json_object_from_request(request) def _serialize_payload(room_id, user_id, change):
"""
Args:
room_id (str)
user_id (str)
change (str): Either "joined" or "left"
"""
assert change in ("joined", "left",)
user_id = content["user_id"] return {}
room_id = content["room_id"]
def _handle_request(self, request, room_id, user_id, change):
logger.info("user membership change: %s in %s", user_id, room_id) logger.info("user membership change: %s in %s", user_id, room_id)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
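A hypothetical sketch of the worker-side wiring that replaces the deleted helper functions; the attribute name is invented:

    # during handler setup
    self._remote_join_client = ReplicationRemoteJoinRestServlet.make_client(hs)

    # later, instead of the old remote_join(client, host, port, ...) helper
    yield self._remote_join_client(
        requester=requester,
        remote_room_hosts=remote_room_hosts,
        room_id=room_id,
        user_id=user_id,
        content=content,
    )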

View File

@ -14,86 +14,26 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import CodeMessageException, HttpResponseException
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID from synapse.types import Requester, UserID
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@defer.inlineCallbacks class ReplicationSendEventRestServlet(ReplicationEndpoint):
def send_event_to_master(clock, store, client, host, port, requester, event, context,
ratelimit, extra_users):
"""Send event to be handled on the master
Args:
clock (synapse.util.Clock)
store (DataStore)
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
requester (Requester)
event (FrozenEvent)
context (EventContext)
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
"""
uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
host, port, event.event_id,
)
serialized_context = yield context.serialize(event, store)
payload = {
"event": event.get_pdu_json(),
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
"context": serialized_context,
"requester": requester.serialize(),
"ratelimit": ratelimit,
"extra_users": [u.to_string() for u in extra_users],
}
try:
# We keep retrying the same request for timeouts. This is so that we
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
try:
result = yield client.put_json(uri, payload)
break
except CodeMessageException as e:
if e.code != 504:
raise
logger.warn("send_event request timed out")
# If we timed out we probably don't need to worry about backing
# off too much, but let's just wait a little anyway.
yield clock.sleep(1)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise e.to_synapse_error()
defer.returnValue(result)
class ReplicationSendEventRestServlet(RestServlet):
"""Handles events newly created on workers, including persisting and """Handles events newly created on workers, including persisting and
notifying. notifying.
The API looks like: The API looks like:
POST /_synapse/replication/send_event/:event_id POST /_synapse/replication/send_event/:event_id/:txn_id
{ {
"event": { .. serialized event .. }, "event": { .. serialized event .. },
@ -105,27 +45,47 @@ class ReplicationSendEventRestServlet(RestServlet):
"extra_users": [], "extra_users": [],
} }
""" """
PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")] NAME = "send_event"
PATH_ARGS = ("event_id",)
def __init__(self, hs): def __init__(self, hs):
super(ReplicationSendEventRestServlet, self).__init__() super(ReplicationSendEventRestServlet, self).__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
# The responses are tiny, so we may as well cache them for a while @staticmethod
self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000) @defer.inlineCallbacks
def _serialize_payload(event_id, store, event, context, requester,
ratelimit, extra_users):
"""
Args:
event_id (str)
store (DataStore)
requester (Requester)
event (FrozenEvent)
context (EventContext)
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
"""
def on_PUT(self, request, event_id): serialized_context = yield context.serialize(event, store)
return self.response_cache.wrap(
event_id, payload = {
self._handle_request, "event": event.get_pdu_json(),
request "internal_metadata": event.internal_metadata.get_dict(),
) "rejected_reason": event.rejected_reason,
"context": serialized_context,
"requester": requester.serialize(),
"ratelimit": ratelimit,
"extra_users": [u.to_string() for u in extra_users],
}
defer.returnValue(payload)
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_request(self, request): def _handle_request(self, request, event_id):
with Measure(self.clock, "repl_send_event_parse"): with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)

View File

@ -44,8 +44,8 @@ class SlavedEventStore(EventFederationWorkerStore,
RoomMemberWorkerStore, RoomMemberWorkerStore,
EventPushActionsWorkerStore, EventPushActionsWorkerStore,
StreamWorkerStore, StreamWorkerStore,
EventsWorkerStore,
StateGroupWorkerStore, StateGroupWorkerStore,
EventsWorkerStore,
SignatureWorkerStore, SignatureWorkerStore,
UserErasureWorkerStore, UserErasureWorkerStore,
BaseSlavedStore): BaseSlavedStore):

View File

@ -13,19 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore from synapse.storage.transactions import TransactionStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
class TransactionStore(BaseSlavedStore): class SlavedTransactionStore(TransactionStore, BaseSlavedStore):
get_destination_retry_timings = TransactionStore.__dict__[ pass
"get_destination_retry_timings"
]
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
_set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
prep_send_transaction = DataStore.prep_send_transaction.__func__
delivered_txn = DataStore.delivered_txn.__func__

View File

@ -107,7 +107,7 @@ class ReplicationClientHandler(object):
Can be overridden in subclasses to handle more. Can be overridden in subclasses to handle more.
""" """
logger.info("Received rdata %s -> %s", stream_name, token) logger.info("Received rdata %s -> %s", stream_name, token)
self.store.process_replication_rows(stream_name, token, rows) return self.store.process_replication_rows(stream_name, token, rows)
def on_position(self, stream_name, token): def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes """Called when we get new position data. By default this just pokes
@ -115,7 +115,7 @@ class ReplicationClientHandler(object):
Can be overridden in subclasses to handle more. Can be overridden in subclasses to handle more.
""" """
self.store.process_replication_rows(stream_name, token, []) return self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data): def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting """When we received a SYNC we wake up any deferreds that were waiting

View File

@ -59,6 +59,12 @@ class Command(object):
""" """
return self.data return self.data
def get_logcontext_id(self):
"""Get a suitable string for the logcontext when processing this command"""
# by default, we just use the command name.
return self.NAME
class ServerCommand(Command): class ServerCommand(Command):
"""Sent by the server on new connection and includes the server_name. """Sent by the server on new connection and includes the server_name.
@ -116,6 +122,9 @@ class RdataCommand(Command):
_json_encoder.encode(self.row), _json_encoder.encode(self.row),
)) ))
def get_logcontext_id(self):
return "RDATA-" + self.stream_name
class PositionCommand(Command): class PositionCommand(Command):
"""Sent by the client to tell the client the stream postition without """Sent by the client to tell the client the stream postition without
@ -190,6 +199,9 @@ class ReplicateCommand(Command):
def to_line(self): def to_line(self):
return " ".join((self.stream_name, str(self.token),)) return " ".join((self.stream_name, str(self.token),))
def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name
class UserSyncCommand(Command): class UserSyncCommand(Command):
"""Sent by the client to inform the server that a user has started or """Sent by the client to inform the server that a user has started or

View File

@ -63,6 +63,8 @@ from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure from twisted.python.failure import Failure
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from .commands import ( from .commands import (
@ -222,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now lets try and call on_<CMD_NAME> function # Now lets try and call on_<CMD_NAME> function
try: try:
getattr(self, "on_%s" % (cmd_name,))(cmd) run_as_background_process(
"replication-" + cmd.get_logcontext_id(),
getattr(self, "on_%s" % (cmd_name,)),
cmd,
)
except Exception: except Exception:
logger.exception("[%s] Failed to handle line: %r", self.id(), line) logger.exception("[%s] Failed to handle line: %r", self.id(), line)
@ -387,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.name = cmd.data self.name = cmd.data
def on_USER_SYNC(self, cmd): def on_USER_SYNC(self, cmd):
self.streamer.on_user_sync( return self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
) )
@ -397,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL": if stream_name == "ALL":
# Subscribe to all streams we're publishing to. # Subscribe to all streams we're publishing to.
for stream in iterkeys(self.streamer.streams_by_name): deferreds = [
self.subscribe_to_stream(stream, token) run_in_background(
self.subscribe_to_stream,
stream, token,
)
for stream in iterkeys(self.streamer.streams_by_name)
]
return make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else: else:
self.subscribe_to_stream(stream_name, token) return self.subscribe_to_stream(stream_name, token)
def on_FEDERATION_ACK(self, cmd): def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token) return self.streamer.federation_ack(cmd.token)
def on_REMOVE_PUSHER(self, cmd): def on_REMOVE_PUSHER(self, cmd):
self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) return self.streamer.on_remove_pusher(
cmd.app_id, cmd.push_key, cmd.user_id,
)
def on_INVALIDATE_CACHE(self, cmd): def on_INVALIDATE_CACHE(self, cmd):
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd): def on_USER_IP(self, cmd):
self.streamer.on_user_ip( return self.streamer.on_user_ip(
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
cmd.last_seen, cmd.last_seen,
) )
@ -542,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates # Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, []) rows = self.pending_batches.pop(stream_name, [])
rows.append(row) rows.append(row)
return self.handler.on_rdata(stream_name, cmd.token, rows)
self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd): def on_POSITION(self, cmd):
self.handler.on_position(cmd.stream_name, cmd.token) return self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd): def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data) return self.handler.on_sync(cmd.data)
def replicate(self, stream_name, token): def replicate(self, stream_name, token):
"""Send the subscription request to the server """Send the subscription request to the server

View File

@ -17,7 +17,7 @@
to ensure idempotency when performing PUTs using the REST API.""" to ensure idempotency when performing PUTs using the REST API."""
import logging import logging
from synapse.util.async import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -53,7 +53,7 @@ class HttpTransactionCache(object):
str: A transaction key str: A transaction key
""" """
token = self.auth.get_access_token_from_request(request) token = self.auth.get_access_token_from_request(request)
return request.path + "/" + token return request.path.decode('utf8') + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs): def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts """A helper function for fetch_or_execute which extracts

View File

@ -391,10 +391,17 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
yield self._deactivate_account_handler.deactivate_account( result = yield self._deactivate_account_handler.deactivate_account(
target_user_id, erase, target_user_id, erase,
) )
defer.returnValue((200, {})) if result:
id_server_unbind_result = "success"
else:
id_server_unbind_result = "no-support"
defer.returnValue((200, {
"id_server_unbind_result": id_server_unbind_result,
}))
class ShutdownRoomRestServlet(ClientV1RestServlet): class ShutdownRoomRestServlet(ClientV1RestServlet):
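With this change the admin deactivation endpoint reports whether the identity server supported unbinding the 3PIDs; an illustrative 200 response body:

    {"id_server_unbind_result": "success"}   # or "no-support"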

Some files were not shown because too many files have changed in this diff.