Merge branch 'develop' of github.com:matrix-org/synapse into erikj/split_federation

This commit is contained in:
Erik Johnston 2018-08-09 10:16:16 +01:00
commit 5785b93711
67 changed files with 1156 additions and 408 deletions

View File

@ -35,3 +35,4 @@ recursive-include changelog.d *
prune .github prune .github
prune demo/etc prune demo/etc
prune docker

View File

@ -157,12 +157,19 @@ if you prefer.
In case of problems, please see the _`Troubleshooting` section below. In case of problems, please see the _`Troubleshooting` section below.
There is an offical synapse image available at https://hub.docker.com/r/matrixdotorg/synapse/tags/ which can be used with the docker-compose file available at `contrib/docker`. Further information on this including configuration options is available in `contrib/docker/README.md`. There is an offical synapse image available at
https://hub.docker.com/r/matrixdotorg/synapse/tags/ which can be used with
the docker-compose file available at `contrib/docker <contrib/docker>`_. Further information on
this including configuration options is available in the README on
hub.docker.com.
Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a Dockerfile to automate a synapse server in a single Docker image, at https://hub.docker.com/r/avhost/docker-matrix/tags/ Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
Dockerfile to automate a synapse server in a single Docker image, at
https://hub.docker.com/r/avhost/docker-matrix/tags/
Also, Martin Giess has created an auto-deployment process with vagrant/ansible, Also, Martin Giess has created an auto-deployment process with vagrant/ansible,
tested with VirtualBox/AWS/DigitalOcean - see https://github.com/EMnify/matrix-synapse-auto-deploy tested with VirtualBox/AWS/DigitalOcean - see
https://github.com/EMnify/matrix-synapse-auto-deploy
for details. for details.
Configuring synapse Configuring synapse

1
changelog.d/3585.bugfix Normal file
View File

@ -0,0 +1 @@
Respond with M_NOT_FOUND when profiles are not found locally or over federation. Fixes #3585

View File

@ -0,0 +1 @@
Refactor HTTP replication endpoints to reduce code duplication

1
changelog.d/3633.feature Normal file
View File

@ -0,0 +1 @@
Add ability to limit number of monthly active users on the server

1
changelog.d/3644.misc Normal file
View File

@ -0,0 +1 @@
Refactor location of docker build script.

1
changelog.d/3647.misc Normal file
View File

@ -0,0 +1 @@
Tests now correctly execute on Python 3.

1
changelog.d/3654.feature Normal file
View File

@ -0,0 +1 @@
Basic support for room versioning

1
changelog.d/3658.bugfix Normal file
View File

@ -0,0 +1 @@
Fix occasional glitches in the synapse_event_persisted_position metric

1
changelog.d/3662.feature Normal file
View File

@ -0,0 +1 @@
Ability to whitelist specific threepids against monthly active user limiting

1
changelog.d/3664.feature Normal file
View File

@ -0,0 +1 @@
Add some metrics for the appservice and federation event sending loops

View File

@ -1,23 +1,5 @@
# Synapse Docker # Synapse Docker
The `matrixdotorg/synapse` Docker image will run Synapse as a single process. It does not provide a
database server or a TURN server, you should run these separately.
If you run a Postgres server, you should simply include it in the same Compose
project or set the proper environment variables and the image will automatically
use that server.
## Build
Build the docker image with the `docker-compose build` command.
You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project.
## Run
This image is designed to run either with an automatically generated configuration
file or with a custom configuration that requires manual edition.
### Automated configuration ### Automated configuration
It is recommended that you use Docker Compose to run your containers, including It is recommended that you use Docker Compose to run your containers, including
@ -54,94 +36,6 @@ Then, customize your configuration and run the server:
docker-compose up -d docker-compose up -d
``` ```
### Without Compose ### More information
If you do not wish to use Compose, you may still run this image using plain For more information on required environment variables and mounts, see the main docker documentation at [/docker/README.md](../../docker/README.md)
Docker commands. Note that the following is just a guideline and you may need
to add parameters to the docker run command to account for the network situation
with your postgres database.
```
docker run \
-d \
--name synapse \
-v ${DATA_PATH}:/data \
-e SYNAPSE_SERVER_NAME=my.matrix.host \
-e SYNAPSE_REPORT_STATS=yes \
docker.io/matrixdotorg/synapse:latest
```
## Volumes
The image expects a single volume, located at ``/data``, that will hold:
* temporary files during uploads;
* uploaded media and thumbnails;
* the SQLite database if you do not configure postgres;
* the appservices configuration.
You are free to use separate volumes depending on storage endpoints at your
disposal. For instance, ``/data/media`` coud be stored on a large but low
performance hdd storage while other files could be stored on high performance
endpoints.
In order to setup an application service, simply create an ``appservices``
directory in the data volume and write the application service Yaml
configuration file there. Multiple application services are supported.
## Environment
Unless you specify a custom path for the configuration file, a very generic
file will be generated, based on the following environment settings.
These are a good starting point for setting up your own deployment.
Global settings:
* ``UID``, the user id Synapse will run as [default 991]
* ``GID``, the group id Synapse will run as [default 991]
* ``SYNAPSE_CONFIG_PATH``, path to a custom config file
If ``SYNAPSE_CONFIG_PATH`` is set, you should generate a configuration file
then customize it manually. No other environment variable is required.
Otherwise, a dynamic configuration file will be used. The following environment
variables are available for configuration:
* ``SYNAPSE_SERVER_NAME`` (mandatory), the current server public hostname.
* ``SYNAPSE_REPORT_STATS``, (mandatory, ``yes`` or ``no``), enable anonymous
statistics reporting back to the Matrix project which helps us to get funding.
* ``SYNAPSE_NO_TLS``, set this variable to disable TLS in Synapse (use this if
you run your own TLS-capable reverse proxy).
* ``SYNAPSE_ENABLE_REGISTRATION``, set this variable to enable registration on
the Synapse instance.
* ``SYNAPSE_ALLOW_GUEST``, set this variable to allow guest joining this server.
* ``SYNAPSE_EVENT_CACHE_SIZE``, the event cache size [default `10K`].
* ``SYNAPSE_CACHE_FACTOR``, the cache factor [default `0.5`].
* ``SYNAPSE_RECAPTCHA_PUBLIC_KEY``, set this variable to the recaptcha public
key in order to enable recaptcha upon registration.
* ``SYNAPSE_RECAPTCHA_PRIVATE_KEY``, set this variable to the recaptcha private
key in order to enable recaptcha upon registration.
* ``SYNAPSE_TURN_URIS``, set this variable to the coma-separated list of TURN
uris to enable TURN for this homeserver.
* ``SYNAPSE_TURN_SECRET``, set this to the TURN shared secret if required.
Shared secrets, that will be initialized to random values if not set:
* ``SYNAPSE_REGISTRATION_SHARED_SECRET``, secret for registrering users if
registration is disable.
* ``SYNAPSE_MACAROON_SECRET_KEY`` secret for signing access tokens
to the server.
Database specific values (will use SQLite if not set):
* `POSTGRES_DB` - The database name for the synapse postgres database. [default: `synapse`]
* `POSTGRES_HOST` - The host of the postgres database if you wish to use postgresql instead of sqlite3. [default: `db` which is useful when using a container on the same docker network in a compose file where the postgres service is called `db`]
* `POSTGRES_PASSWORD` - The password for the synapse postgres database. **If this is set then postgres will be used instead of sqlite3.** [default: none] **NOTE**: You are highly encouraged to use postgresql! Please use the compose file to make it easier to deploy.
* `POSTGRES_USER` - The user for the synapse postgres database. [default: `matrix`]
Mail server specific values (will not send emails if not set):
* ``SYNAPSE_SMTP_HOST``, hostname to the mail server.
* ``SYNAPSE_SMTP_PORT``, TCP port for accessing the mail server [default ``25``].
* ``SYNAPSE_SMTP_USER``, username for authenticating against the mail server if any.
* ``SYNAPSE_SMTP_PASSWORD``, password for authenticating against the mail server if any.

View File

@ -54,7 +54,7 @@
"gnetId": null, "gnetId": null,
"graphTooltip": 0, "graphTooltip": 0,
"id": null, "id": null,
"iteration": 1533026624326, "iteration": 1533598785368,
"links": [ "links": [
{ {
"asDropdown": true, "asDropdown": true,
@ -4629,7 +4629,7 @@
"h": 9, "h": 9,
"w": 12, "w": 12,
"x": 0, "x": 0,
"y": 11 "y": 29
}, },
"id": 67, "id": 67,
"legend": { "legend": {
@ -4655,11 +4655,11 @@
"steppedLine": false, "steppedLine": false,
"targets": [ "targets": [
{ {
"expr": " synapse_event_persisted_position{instance=\"$instance\"} - ignoring(index, job, name) group_right(instance) synapse_event_processing_positions{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", "expr": " synapse_event_persisted_position{instance=\"$instance\",job=\"synapse\"} - ignoring(index, job, name) group_right() synapse_event_processing_positions{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
"format": "time_series", "format": "time_series",
"interval": "", "interval": "",
"intervalFactor": 1, "intervalFactor": 1,
"legendFormat": "{{job}}-{{index}}", "legendFormat": "{{job}}-{{index}} ",
"refId": "A" "refId": "A"
} }
], ],
@ -4697,7 +4697,11 @@
"min": null, "min": null,
"show": true "show": true
} }
] ],
"yaxis": {
"align": false,
"alignLevel": null
}
}, },
{ {
"aliasColors": {}, "aliasColors": {},
@ -4710,7 +4714,7 @@
"h": 9, "h": 9,
"w": 12, "w": 12,
"x": 12, "x": 12,
"y": 11 "y": 29
}, },
"id": 71, "id": 71,
"legend": { "legend": {
@ -4778,7 +4782,11 @@
"min": null, "min": null,
"show": true "show": true
} }
] ],
"yaxis": {
"align": false,
"alignLevel": null
}
} }
], ],
"title": "Event processing loop positions", "title": "Event processing loop positions",
@ -4957,5 +4965,5 @@
"timezone": "", "timezone": "",
"title": "Synapse", "title": "Synapse",
"uid": "000000012", "uid": "000000012",
"version": 125 "version": 127
} }

View File

@ -22,7 +22,7 @@ RUN cd /synapse \
setuptools \ setuptools \
&& mkdir -p /synapse/cache \ && mkdir -p /synapse/cache \
&& pip install -f /synapse/cache --upgrade --process-dependency-links . \ && pip install -f /synapse/cache --upgrade --process-dependency-links . \
&& mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \ && mv /synapse/docker/start.py /synapse/docker/conf / \
&& rm -rf \ && rm -rf \
setup.cfg \ setup.cfg \
setup.py \ setup.py \

124
docker/README.md Normal file
View File

@ -0,0 +1,124 @@
# Synapse Docker
This Docker image will run Synapse as a single process. It does not provide a database
server or a TURN server, you should run these separately.
## Run
We do not currently offer a `latest` image, as this has somewhat undefined semantics.
We instead release only tagged versions so upgrading between releases is entirely
within your control.
### Using docker-compose (easier)
This image is designed to run either with an automatically generated configuration
file or with a custom configuration that requires manual editing.
An easy way to make use of this image is via docker-compose. See the
[contrib/docker](../contrib/docker)
section of the synapse project for examples.
### Without Compose (harder)
If you do not wish to use Compose, you may still run this image using plain
Docker commands. Note that the following is just a guideline and you may need
to add parameters to the docker run command to account for the network situation
with your postgres database.
```
docker run \
-d \
--name synapse \
-v ${DATA_PATH}:/data \
-e SYNAPSE_SERVER_NAME=my.matrix.host \
-e SYNAPSE_REPORT_STATS=yes \
docker.io/matrixdotorg/synapse:latest
```
## Volumes
The image expects a single volume, located at ``/data``, that will hold:
* temporary files during uploads;
* uploaded media and thumbnails;
* the SQLite database if you do not configure postgres;
* the appservices configuration.
You are free to use separate volumes depending on storage endpoints at your
disposal. For instance, ``/data/media`` coud be stored on a large but low
performance hdd storage while other files could be stored on high performance
endpoints.
In order to setup an application service, simply create an ``appservices``
directory in the data volume and write the application service Yaml
configuration file there. Multiple application services are supported.
## Environment
Unless you specify a custom path for the configuration file, a very generic
file will be generated, based on the following environment settings.
These are a good starting point for setting up your own deployment.
Global settings:
* ``UID``, the user id Synapse will run as [default 991]
* ``GID``, the group id Synapse will run as [default 991]
* ``SYNAPSE_CONFIG_PATH``, path to a custom config file
If ``SYNAPSE_CONFIG_PATH`` is set, you should generate a configuration file
then customize it manually. No other environment variable is required.
Otherwise, a dynamic configuration file will be used. The following environment
variables are available for configuration:
* ``SYNAPSE_SERVER_NAME`` (mandatory), the current server public hostname.
* ``SYNAPSE_REPORT_STATS``, (mandatory, ``yes`` or ``no``), enable anonymous
statistics reporting back to the Matrix project which helps us to get funding.
* ``SYNAPSE_NO_TLS``, set this variable to disable TLS in Synapse (use this if
you run your own TLS-capable reverse proxy).
* ``SYNAPSE_ENABLE_REGISTRATION``, set this variable to enable registration on
the Synapse instance.
* ``SYNAPSE_ALLOW_GUEST``, set this variable to allow guest joining this server.
* ``SYNAPSE_EVENT_CACHE_SIZE``, the event cache size [default `10K`].
* ``SYNAPSE_CACHE_FACTOR``, the cache factor [default `0.5`].
* ``SYNAPSE_RECAPTCHA_PUBLIC_KEY``, set this variable to the recaptcha public
key in order to enable recaptcha upon registration.
* ``SYNAPSE_RECAPTCHA_PRIVATE_KEY``, set this variable to the recaptcha private
key in order to enable recaptcha upon registration.
* ``SYNAPSE_TURN_URIS``, set this variable to the coma-separated list of TURN
uris to enable TURN for this homeserver.
* ``SYNAPSE_TURN_SECRET``, set this to the TURN shared secret if required.
Shared secrets, that will be initialized to random values if not set:
* ``SYNAPSE_REGISTRATION_SHARED_SECRET``, secret for registrering users if
registration is disable.
* ``SYNAPSE_MACAROON_SECRET_KEY`` secret for signing access tokens
to the server.
Database specific values (will use SQLite if not set):
* `POSTGRES_DB` - The database name for the synapse postgres database. [default: `synapse`]
* `POSTGRES_HOST` - The host of the postgres database if you wish to use postgresql instead of sqlite3. [default: `db` which is useful when using a container on the same docker network in a compose file where the postgres service is called `db`]
* `POSTGRES_PASSWORD` - The password for the synapse postgres database. **If this is set then postgres will be used instead of sqlite3.** [default: none] **NOTE**: You are highly encouraged to use postgresql! Please use the compose file to make it easier to deploy.
* `POSTGRES_USER` - The user for the synapse postgres database. [default: `matrix`]
Mail server specific values (will not send emails if not set):
* ``SYNAPSE_SMTP_HOST``, hostname to the mail server.
* ``SYNAPSE_SMTP_PORT``, TCP port for accessing the mail server [default ``25``].
* ``SYNAPSE_SMTP_USER``, username for authenticating against the mail server if any.
* ``SYNAPSE_SMTP_PASSWORD``, password for authenticating against the mail server if any.
## Build
Build the docker image with the `docker build` command from the root of the synapse repository.
```
docker build -t docker.io/matrixdotorg/synapse . -f docker/Dockerfile
```
The `-t` option sets the image tag. Official images are tagged `matrixdotorg/synapse:<version>` where `<version>` is the same as the release tag in the synapse git repository.
You may have a local Python wheel cache available, in which case copy the relevant
packages in the ``cache/`` directory at the root of the project.

View File

@ -213,7 +213,7 @@ class Auth(object):
default=[b""] default=[b""]
)[0] )[0]
if user and access_token and ip_addr: if user and access_token and ip_addr:
self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id=user.to_string(), user_id=user.to_string(),
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
@ -773,3 +773,15 @@ class Auth(object):
raise AuthError( raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
) )
@defer.inlineCallbacks
def check_auth_blocking(self):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
"""
if self.hs.config.limit_usage_by_mau is True:
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
)

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd # Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -94,3 +95,11 @@ class RoomCreationPreset(object):
class ThirdPartyEntityKind(object): class ThirdPartyEntityKind(object):
USER = "user" USER = "user"
LOCATION = "location" LOCATION = "location"
# the version we will give rooms which are created on this server
DEFAULT_ROOM_VERSION = "1"
# vdh-test-version is a placeholder to get room versioning support working and tested
# until we have a working v2.
KNOWN_ROOM_VERSIONS = {"1", "vdh-test-version"}

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -56,6 +57,8 @@ class Codes(object):
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED" MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION"
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):
@ -285,6 +288,27 @@ class LimitExceededError(SynapseError):
) )
class IncompatibleRoomVersionError(SynapseError):
"""A server is trying to join a room whose version it does not support."""
def __init__(self, room_version):
super(IncompatibleRoomVersionError, self).__init__(
code=400,
msg="Your homeserver does not support the features required to "
"join this room",
errcode=Codes.INCOMPATIBLE_ROOM_VERSION,
)
self._room_version = room_version
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
room_version=self._room_version,
)
def cs_error(msg, code=Codes.UNKNOWN, **kwargs): def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
""" Utility method for constructing an error response for client-server """ Utility method for constructing an error response for client-server
interactions. interactions.

View File

@ -519,17 +519,26 @@ def run(hs):
# table will decrease # table will decrease
clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
# monthly active user limiting functionality
clock.looping_call(
hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60
)
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_monthly_active_users(): def generate_monthly_active_users():
count = 0 count = 0
if hs.config.limit_usage_by_mau: if hs.config.limit_usage_by_mau:
count = yield hs.get_datastore().count_monthly_users() count = yield hs.get_datastore().get_monthly_active_count()
current_mau_gauge.set(float(count)) current_mau_gauge.set(float(count))
max_mau_value_gauge.set(float(hs.config.max_mau_value)) max_mau_value_gauge.set(float(hs.config.max_mau_value))
hs.get_datastore().initialise_reserved_users(
hs.config.mau_limits_reserved_threepids
)
generate_monthly_active_users() generate_monthly_active_users()
if hs.config.limit_usage_by_mau: if hs.config.limit_usage_by_mau:
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
# End of monthly active user settings
if hs.config.report_stats: if hs.config.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals") logger.info("Scheduling stats reporting for 3 hour intervals")

View File

@ -69,12 +69,15 @@ class ServerConfig(Config):
# Options to control access by tracking MAU # Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
self.max_mau_value = 0
if self.limit_usage_by_mau: if self.limit_usage_by_mau:
self.max_mau_value = config.get( self.max_mau_value = config.get(
"max_mau_value", 0, "max_mau_value", 0,
) )
else: self.mau_limits_reserved_threepids = config.get(
self.max_mau_value = 0 "mau_limit_reserved_threepids", []
)
# FIXME: federation_domain_whitelist needs sytests # FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None self.federation_domain_whitelist = None
federation_domain_whitelist = config.get( federation_domain_whitelist = config.get(

View File

@ -20,7 +20,7 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, EventSizeError, SynapseError from synapse.api.errors import AuthError, EventSizeError, SynapseError
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
@ -83,6 +83,14 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
403, 403,
"Creation event's room_id domain does not match sender's" "Creation event's room_id domain does not match sender's"
) )
room_version = event.content.get("room_version", "1")
if room_version not in KNOWN_ROOM_VERSIONS:
raise AuthError(
403,
"room appears to have unsupported version %s" % (
room_version,
))
# FIXME # FIXME
logger.debug("Allowing! %s", event) logger.debug("Allowing! %s", event)
return return

View File

@ -25,7 +25,7 @@ from prometheus_client import Counter
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, CodeMessageException,
FederationDeniedError, FederationDeniedError,
@ -518,10 +518,10 @@ class FederationClient(FederationBase):
description, destination, exc_info=1, description, destination, exc_info=1,
) )
raise RuntimeError("Failed to %s via any server", description) raise RuntimeError("Failed to %s via any server" % (description, ))
def make_membership_event(self, destinations, room_id, user_id, membership, def make_membership_event(self, destinations, room_id, user_id, membership,
content={},): content, params):
""" """
Creates an m.room.member event, with context, without participating in the room. Creates an m.room.member event, with context, without participating in the room.
@ -537,8 +537,10 @@ class FederationClient(FederationBase):
user_id (str): The user whose membership is being evented. user_id (str): The user whose membership is being evented.
membership (str): The "membership" property of the event. Must be membership (str): The "membership" property of the event. Must be
one of "join" or "leave". one of "join" or "leave".
content (object): Any additional data to put into the content field content (dict): Any additional data to put into the content field
of the event. of the event.
params (dict[str, str|Iterable[str]]): Query parameters to include in the
request.
Return: Return:
Deferred: resolves to a tuple of (origin (str), event (object)) Deferred: resolves to a tuple of (origin (str), event (object))
where origin is the remote homeserver which generated the event. where origin is the remote homeserver which generated the event.
@ -558,10 +560,12 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_request(destination): def send_request(destination):
ret = yield self.transport_layer.make_membership_event( ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership destination, room_id, user_id, membership, params,
) )
pdu_dict = ret["event"] pdu_dict = ret.get("event", None)
if not isinstance(pdu_dict, dict):
raise InvalidResponseError("Bad 'event' field in response")
logger.debug("Got response to make_%s: %s", membership, pdu_dict) logger.debug("Got response to make_%s: %s", membership, pdu_dict)
@ -605,6 +609,26 @@ class FederationClient(FederationBase):
Fails with a ``RuntimeError`` if no servers were reachable. Fails with a ``RuntimeError`` if no servers were reachable.
""" """
def check_authchain_validity(signed_auth_chain):
for e in signed_auth_chain:
if e.type == EventTypes.Create:
create_event = e
break
else:
raise InvalidResponseError(
"no %s in auth chain" % (EventTypes.Create,),
)
# the room version should be sane.
room_version = create_event.content.get("room_version", "1")
if room_version not in KNOWN_ROOM_VERSIONS:
# This shouldn't be possible, because the remote server should have
# rejected the join attempt during make_join.
raise InvalidResponseError(
"room appears to have unsupported version %s" % (
room_version,
))
@defer.inlineCallbacks @defer.inlineCallbacks
def send_request(destination): def send_request(destination):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
@ -661,7 +685,7 @@ class FederationClient(FederationBase):
for s in signed_state: for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata) s.internal_metadata = copy.deepcopy(s.internal_metadata)
auth_chain.sort(key=lambda e: e.depth) check_authchain_validity(signed_auth)
defer.returnValue({ defer.returnValue({
"state": signed_state, "state": signed_state,

View File

@ -27,7 +27,13 @@ from twisted.internet.abstract import isIPAddress
from twisted.python import failure from twisted.python import failure
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, FederationError, NotFoundError, SynapseError from synapse.api.errors import (
AuthError,
FederationError,
IncompatibleRoomVersionError,
NotFoundError,
SynapseError,
)
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
@ -327,12 +333,21 @@ class FederationServer(FederationBase):
defer.returnValue((200, resp)) defer.returnValue((200, resp))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_make_join_request(self, origin, room_id, user_id): def on_make_join_request(self, origin, room_id, user_id, supported_versions):
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id) yield self.check_server_matches_acl(origin_host, room_id)
room_version = yield self.store.get_room_version(room_id)
if room_version not in supported_versions:
logger.warn("Room version %s not in %s", room_version, supported_versions)
raise IncompatibleRoomVersionError(room_version=room_version)
pdu = yield self.handler.on_make_join_request(room_id, user_id) pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)}) defer.returnValue({
"event": pdu.get_pdu_json(time_now),
"room_version": room_version,
})
@defer.inlineCallbacks @defer.inlineCallbacks
def on_invite_request(self, origin, content): def on_invite_request(self, origin, content):

View File

@ -26,6 +26,8 @@ from synapse.api.errors import FederationDeniedError, HttpResponseException
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
from synapse.metrics import ( from synapse.metrics import (
LaterGauge, LaterGauge,
event_processing_loop_counter,
event_processing_loop_room_count,
events_processed_counter, events_processed_counter,
sent_edus_counter, sent_edus_counter,
sent_transactions_counter, sent_transactions_counter,
@ -253,7 +255,13 @@ class TransactionQueue(object):
synapse.metrics.event_processing_last_ts.labels( synapse.metrics.event_processing_last_ts.labels(
"federation_sender").set(ts) "federation_sender").set(ts)
events_processed_counter.inc(len(events)) events_processed_counter.inc(len(events))
event_processing_loop_room_count.labels(
"federation_sender"
).inc(len(events_by_room))
event_processing_loop_counter.labels("federation_sender").inc()
synapse.metrics.event_processing_positions.labels( synapse.metrics.event_processing_positions.labels(
"federation_sender").set(next_token) "federation_sender").set(next_token)

View File

@ -195,7 +195,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_membership_event(self, destination, room_id, user_id, membership): def make_membership_event(self, destination, room_id, user_id, membership, params):
"""Asks a remote server to build and sign us a membership event """Asks a remote server to build and sign us a membership event
Note that this does not append any events to any graphs. Note that this does not append any events to any graphs.
@ -205,6 +205,8 @@ class TransportLayerClient(object):
room_id (str): room to join/leave room_id (str): room to join/leave
user_id (str): user to be joined/left user_id (str): user to be joined/left
membership (str): one of join/leave membership (str): one of join/leave
params (dict[str, str|Iterable[str]]): Query parameters to include in the
request.
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
@ -241,6 +243,7 @@ class TransportLayerClient(object):
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
args=params,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
timeout=20000, timeout=20000,
ignore_backoff=ignore_backoff, ignore_backoff=ignore_backoff,

View File

@ -190,6 +190,41 @@ def _parse_auth_header(header_bytes):
class BaseFederationServlet(object): class BaseFederationServlet(object):
"""Abstract base class for federation servlet classes.
The servlet object should have a PATH attribute which takes the form of a regexp to
match against the request path (excluding the /federation/v1 prefix).
The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
the appropriate HTTP method. These methods have the signature:
on_<METHOD>(self, origin, content, query, **kwargs)
With arguments:
origin (unicode|None): The authenticated server_name of the calling server,
unless REQUIRE_AUTH is set to False and authentication failed.
content (unicode|None): decoded json body of the request. None if the
request was a GET.
query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
(ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
yet.
**kwargs (dict[unicode, unicode]): the dict mapping keys to path
components as specified in the path match regexp.
Returns:
Deferred[(int, object)|None]: either (response code, response object) to
return a JSON response, or None if the request has already been handled.
Raises:
SynapseError: to return an error code
Exception: other exceptions will be caught, logged, and a 500 will be
returned.
"""
REQUIRE_AUTH = True REQUIRE_AUTH = True
def __init__(self, handler, authenticator, ratelimiter, server_name): def __init__(self, handler, authenticator, ratelimiter, server_name):
@ -204,6 +239,18 @@ class BaseFederationServlet(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@functools.wraps(func) @functools.wraps(func)
def new_func(request, *args, **kwargs): def new_func(request, *args, **kwargs):
""" A callback which can be passed to HttpServer.RegisterPaths
Args:
request (twisted.web.http.Request):
*args: unused?
**kwargs (dict[unicode, unicode]): the dict mapping keys to path
components as specified in the path match regexp.
Returns:
Deferred[(int, object)|None]: (response code, response object) as returned
by the callback method. None if the request has already been handled.
"""
content = None content = None
if request.method in ["PUT", "POST"]: if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types? # TODO: Handle other method types? other content types?
@ -384,9 +431,31 @@ class FederationMakeJoinServlet(BaseFederationServlet):
PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)" PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id): def on_GET(self, origin, _content, query, context, user_id):
"""
Args:
origin (unicode): The authenticated server_name of the calling server
_content (None): (GETs don't have bodies)
query (dict[bytes, list[bytes]]): Query params from the request.
**kwargs (dict[unicode, unicode]): the dict mapping keys to path
components as specified in the path match regexp.
Returns:
Deferred[(int, object)|None]: either (response code, response object) to
return a JSON response, or None if the request has already been handled.
"""
versions = query.get(b'ver')
if versions is not None:
supported_versions = [v.decode("utf-8") for v in versions]
else:
supported_versions = ["1"]
content = yield self.handler.on_make_join_request( content = yield self.handler.on_make_join_request(
origin, context, user_id, origin, context, user_id,
supported_versions=supported_versions,
) )
defer.returnValue((200, content)) defer.returnValue((200, content))

View File

@ -23,6 +23,10 @@ from twisted.internet import defer
import synapse import synapse
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.metrics import (
event_processing_loop_counter,
event_processing_loop_room_count,
)
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -136,6 +140,12 @@ class ApplicationServicesHandler(object):
events_processed_counter.inc(len(events)) events_processed_counter.inc(len(events))
event_processing_loop_room_count.labels(
"appservice_sender"
).inc(len(events_by_room))
event_processing_loop_counter.labels("appservice_sender").inc()
synapse.metrics.event_processing_lag.labels( synapse.metrics.event_processing_lag.labels(
"appservice_sender").set(now - ts) "appservice_sender").set(now - ts)
synapse.metrics.event_processing_last_ts.labels( synapse.metrics.event_processing_last_ts.labels(

View File

@ -520,7 +520,7 @@ class AuthHandler(BaseHandler):
""" """
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
yield self._check_mau_limits() yield self.auth.check_auth_blocking()
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we # it's possible we raced against a DELETE operation. The thing we
@ -734,7 +734,7 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
yield self._check_mau_limits() yield self.auth.check_auth_blocking()
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
user_id = None user_id = None
try: try:
@ -907,19 +907,6 @@ class AuthHandler(BaseHandler):
else: else:
return defer.succeed(False) return defer.succeed(False)
@defer.inlineCallbacks
def _check_mau_limits(self):
"""
Ensure that if mau blocking is enabled that invalid users cannot
log in.
"""
if self.hs.config.limit_usage_by_mau is True:
current_mau = yield self.store.count_monthly_users()
if current_mau >= self.hs.config.max_mau_value:
raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
)
@attr.s @attr.s
class MacaroonGenerator(object): class MacaroonGenerator(object):

View File

@ -30,7 +30,12 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import (
KNOWN_ROOM_VERSIONS,
EventTypes,
Membership,
RejectedReason,
)
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
CodeMessageException, CodeMessageException,
@ -935,6 +940,9 @@ class FederationHandler(BaseHandler):
joinee, joinee,
"join", "join",
content, content,
params={
"ver": KNOWN_ROOM_VERSIONS,
},
) )
# This shouldn't happen, because the RoomMemberHandler has a # This shouldn't happen, because the RoomMemberHandler has a
@ -1200,13 +1208,14 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
content={},): content={}, params=None):
origin, pdu = yield self.federation_client.make_membership_event( origin, pdu = yield self.federation_client.make_membership_event(
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
membership, membership,
content, content,
params=params,
) )
logger.debug("Got response to make_%s: %s", membership, pdu) logger.debug("Got response to make_%s: %s", membership, pdu)

View File

@ -17,7 +17,13 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, CodeMessageException, SynapseError from synapse.api.errors import (
AuthError,
CodeMessageException,
Codes,
StoreError,
SynapseError,
)
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
@ -49,12 +55,17 @@ class ProfileHandler(BaseHandler):
def get_profile(self, user_id): def get_profile(self, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
displayname = yield self.store.get_profile_displayname( try:
target_user.localpart displayname = yield self.store.get_profile_displayname(
) target_user.localpart
avatar_url = yield self.store.get_profile_avatar_url( )
target_user.localpart avatar_url = yield self.store.get_profile_avatar_url(
) target_user.localpart
)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
defer.returnValue({ defer.returnValue({
"displayname": displayname, "displayname": displayname,
@ -74,7 +85,6 @@ class ProfileHandler(BaseHandler):
except CodeMessageException as e: except CodeMessageException as e:
if e.code != 404: if e.code != 404:
logger.exception("Failed to get displayname") logger.exception("Failed to get displayname")
raise raise
@defer.inlineCallbacks @defer.inlineCallbacks
@ -85,12 +95,17 @@ class ProfileHandler(BaseHandler):
""" """
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
displayname = yield self.store.get_profile_displayname( try:
target_user.localpart displayname = yield self.store.get_profile_displayname(
) target_user.localpart
avatar_url = yield self.store.get_profile_avatar_url( )
target_user.localpart avatar_url = yield self.store.get_profile_avatar_url(
) target_user.localpart
)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
defer.returnValue({ defer.returnValue({
"displayname": displayname, "displayname": displayname,
@ -103,9 +118,14 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_displayname(self, target_user): def get_displayname(self, target_user):
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
displayname = yield self.store.get_profile_displayname( try:
target_user.localpart displayname = yield self.store.get_profile_displayname(
) target_user.localpart
)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
defer.returnValue(displayname) defer.returnValue(displayname)
else: else:
@ -122,7 +142,6 @@ class ProfileHandler(BaseHandler):
except CodeMessageException as e: except CodeMessageException as e:
if e.code != 404: if e.code != 404:
logger.exception("Failed to get displayname") logger.exception("Failed to get displayname")
raise raise
except Exception: except Exception:
logger.exception("Failed to get displayname") logger.exception("Failed to get displayname")
@ -157,10 +176,14 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_avatar_url(self, target_user): def get_avatar_url(self, target_user):
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
avatar_url = yield self.store.get_profile_avatar_url( try:
target_user.localpart avatar_url = yield self.store.get_profile_avatar_url(
) target_user.localpart
)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
defer.returnValue(avatar_url) defer.returnValue(avatar_url)
else: else:
try: try:
@ -213,16 +236,20 @@ class ProfileHandler(BaseHandler):
just_field = args.get("field", None) just_field = args.get("field", None)
response = {} response = {}
try:
if just_field is None or just_field == "displayname":
response["displayname"] = yield self.store.get_profile_displayname(
user.localpart
)
if just_field is None or just_field == "displayname": if just_field is None or just_field == "avatar_url":
response["displayname"] = yield self.store.get_profile_displayname( response["avatar_url"] = yield self.store.get_profile_avatar_url(
user.localpart user.localpart
) )
except StoreError as e:
if just_field is None or just_field == "avatar_url": if e.code == 404:
response["avatar_url"] = yield self.store.get_profile_avatar_url( raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
user.localpart raise
)
defer.returnValue(response) defer.returnValue(response)

View File

@ -540,9 +540,7 @@ class RegistrationHandler(BaseHandler):
Do not accept registrations if monthly active user limits exceeded Do not accept registrations if monthly active user limits exceeded
and limiting is enabled and limiting is enabled
""" """
if self.hs.config.limit_usage_by_mau is True: try:
current_mau = yield self.store.count_monthly_users() yield self.auth.check_auth_blocking()
if current_mau >= self.hs.config.max_mau_value: except AuthError as e:
raise RegistrationError( raise RegistrationError(e.code, str(e), e.errcode)
403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
)

View File

@ -21,9 +21,17 @@ import math
import string import string
from collections import OrderedDict from collections import OrderedDict
from six import string_types
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset from synapse.api.constants import (
DEFAULT_ROOM_VERSION,
KNOWN_ROOM_VERSIONS,
EventTypes,
JoinRules,
RoomCreationPreset,
)
from synapse.api.errors import AuthError, Codes, StoreError, SynapseError from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils from synapse.util import stringutils
@ -99,6 +107,21 @@ class RoomCreationHandler(BaseHandler):
if ratelimit: if ratelimit:
yield self.ratelimit(requester) yield self.ratelimit(requester)
room_version = config.get("room_version", DEFAULT_ROOM_VERSION)
if not isinstance(room_version, string_types):
raise SynapseError(
400,
"room_version must be a string",
Codes.BAD_JSON,
)
if room_version not in KNOWN_ROOM_VERSIONS:
raise SynapseError(
400,
"Your homeserver does not support this room version",
Codes.UNSUPPORTED_ROOM_VERSION,
)
if "room_alias_name" in config: if "room_alias_name" in config:
for wchar in string.whitespace: for wchar in string.whitespace:
if wchar in config["room_alias_name"]: if wchar in config["room_alias_name"]:
@ -184,6 +207,9 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {}) creation_content = config.get("creation_content", {})
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version
room_member_handler = self.hs.get_room_member_handler() room_member_handler = self.hs.get_room_member_handler()
yield self._send_events_for_new_room( yield self._send_events_for_new_room(

View File

@ -439,7 +439,7 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True, def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
timeout=None, ignore_backoff=False): timeout=None, ignore_backoff=False):
""" GETs some json from the given host homeserver and path """ GETs some json from the given host homeserver and path
@ -447,7 +447,7 @@ class MatrixFederationHttpClient(object):
destination (str): The remote server to send the HTTP request destination (str): The remote server to send the HTTP request
to. to.
path (str): The HTTP path. path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to args (dict|None): A dictionary used to create query strings, defaults to
None. None.
timeout (int): How long to try (in ms) the destination for before timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will giving up. None indicates no timeout and that the request will
@ -702,6 +702,9 @@ def check_content_type_is_json(headers):
def encode_query_args(args): def encode_query_args(args):
if args is None:
return b""
encoded_args = {} encoded_args = {}
for k, vs in args.items(): for k, vs in args.items():
if isinstance(vs, string_types): if isinstance(vs, string_types):

View File

@ -174,6 +174,19 @@ sent_transactions_counter = Counter("synapse_federation_client_sent_transactions
events_processed_counter = Counter("synapse_federation_client_events_processed", "") events_processed_counter = Counter("synapse_federation_client_events_processed", "")
event_processing_loop_counter = Counter(
"synapse_event_processing_loop_count",
"Event processing loop iterations",
["name"],
)
event_processing_loop_room_count = Counter(
"synapse_event_processing_loop_room_count",
"Rooms seen per event processing loop iteration",
["name"],
)
# Used to track where various components have processed in the event stream, # Used to track where various components have processed in the event stream,
# e.g. federation sending, appservice sending, etc. # e.g. federation sending, appservice sending, etc.
event_processing_positions = Gauge("synapse_event_processing_positions", "", ["name"]) event_processing_positions = Gauge("synapse_event_processing_positions", "", ["name"])

View File

@ -40,8 +40,8 @@ class ReplicationEndpoint(object):
/_synapse/replication/send_event/:event_id/:txn_id /_synapse/replication/send_event/:event_id/:txn_id
For POST requests the payload is serialized to json and sent as the body, For POST/PUT requests the payload is serialized to json and sent as the
while for GET requests the payload is added as query parameters. See body, while for GET requests the payload is added as query parameters. See
`_serialize_payload` for details. `_serialize_payload` for details.
Incoming requests are handled by overriding `_handle_request`. Servers Incoming requests are handled by overriding `_handle_request`. Servers
@ -55,8 +55,9 @@ class ReplicationEndpoint(object):
PATH_ARGS (tuple[str]): A list of parameters to be added to the path. PATH_ARGS (tuple[str]): A list of parameters to be added to the path.
Adding parameters to the path (rather than payload) can make it Adding parameters to the path (rather than payload) can make it
easier to follow along in the log files. easier to follow along in the log files.
POST (bool): True to use POST request with JSON body, or false to use METHOD (str): The method of the HTTP request, defaults to POST. Can be
GET requests with query params. one of POST, PUT or GET. If GET then the payload is sent as query
parameters rather than a JSON body.
CACHE (bool): Whether server should cache the result of the request/ CACHE (bool): Whether server should cache the result of the request/
If true then transparently adds a txn_id to all requests, and If true then transparently adds a txn_id to all requests, and
`_handle_request` must return a Deferred. `_handle_request` must return a Deferred.
@ -69,7 +70,7 @@ class ReplicationEndpoint(object):
NAME = abc.abstractproperty() NAME = abc.abstractproperty()
PATH_ARGS = abc.abstractproperty() PATH_ARGS = abc.abstractproperty()
POST = True METHOD = "POST"
CACHE = True CACHE = True
RETRY_ON_TIMEOUT = True RETRY_ON_TIMEOUT = True
@ -80,6 +81,8 @@ class ReplicationEndpoint(object):
timeout_ms=30 * 60 * 1000, timeout_ms=30 * 60 * 1000,
) )
assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod @abc.abstractmethod
def _serialize_payload(**kwargs): def _serialize_payload(**kwargs):
"""Static method that is called when creating a request. """Static method that is called when creating a request.
@ -90,9 +93,9 @@ class ReplicationEndpoint(object):
argument list. argument list.
Returns: Returns:
Deferred[dict]|dict: If POST request then dictionary must be JSON Deferred[dict]|dict: If POST/PUT request then dictionary must be
serialisable, otherwise must be appropriate for adding as query JSON serialisable, otherwise must be appropriate for adding as
args. query args.
""" """
return {} return {}
@ -130,10 +133,18 @@ class ReplicationEndpoint(object):
txn_id = random_string(10) txn_id = random_string(10)
url_args.append(txn_id) url_args.append(txn_id)
if cls.POST: if cls.METHOD == "POST":
request_func = client.post_json_get_json request_func = client.post_json_get_json
else: elif cls.METHOD == "PUT":
request_func = client.put_json
elif cls.METHOD == "GET":
request_func = client.get_json request_func = client.get_json
else:
# We have already asserted in the constructor that a
# compatible was picked, but lets be paranoid.
raise Exception(
"Unknown METHOD on %s replication endpoint" % (cls.NAME,)
)
uri = "http://%s:%s/_synapse/replication/%s/%s" % ( uri = "http://%s:%s/_synapse/replication/%s/%s" % (
host, port, cls.NAME, "/".join(url_args) host, port, cls.NAME, "/".join(url_args)
@ -151,7 +162,7 @@ class ReplicationEndpoint(object):
if e.code != 504 or not cls.RETRY_ON_TIMEOUT: if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
raise raise
logger.warn("send_federation_events_to_master request timed out") logger.warn("%s request timed out", cls.NAME)
# If we timed out we probably don't need to worry about backing # If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway. # off too much, but lets just wait a little anyway.
@ -172,10 +183,8 @@ class ReplicationEndpoint(object):
""" """
url_args = list(self.PATH_ARGS) url_args = list(self.PATH_ARGS)
method = "GET"
handler = self._handle_request handler = self._handle_request
if self.POST: method = self.METHOD
method = "POST"
if self.CACHE: if self.CACHE:
handler = self._cached_handler handler = self._cached_handler
@ -190,7 +199,9 @@ class ReplicationEndpoint(object):
http_server.register_paths(method, [pattern], handler) http_server.register_paths(method, [pattern], handler)
def _cached_handler(self, request, txn_id, **kwargs): def _cached_handler(self, request, txn_id, **kwargs):
"""Wraps `_handle_request` the responses should be cached. """Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
""" """
# We just use the txn_id here, but we probably also want to use the # We just use the txn_id here, but we probably also want to use the
# other PATH_ARGS as well. # other PATH_ARGS as well.

View File

@ -27,6 +27,16 @@ logger = logging.getLogger(__name__)
class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"""Does a remote join for the given user to the given room """Does a remote join for the given user to the given room
Request format:
POST /_synapse/replication/remote_join/:room_id/:user_id
{
"requester": ...,
"remote_room_hosts": [...],
"content": { ... }
}
""" """
NAME = "remote_join" NAME = "remote_join"
@ -85,6 +95,15 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"""Rejects the invite for the user and room. """Rejects the invite for the user and room.
Request format:
POST /_synapse/replication/remote_reject_invite/:room_id/:user_id
{
"requester": ...,
"remote_room_hosts": [...],
}
""" """
NAME = "remote_reject_invite" NAME = "remote_reject_invite"
@ -153,6 +172,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint): class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint):
"""Gets/creates a guest account for given 3PID. """Gets/creates a guest account for given 3PID.
Request format:
POST /_synapse/replication/get_or_register_3pid_guest/
{
"requester": ...,
"medium": ...,
"address": ...,
"inviter_user_id": ...
}
""" """
NAME = "get_or_register_3pid_guest" NAME = "get_or_register_3pid_guest"
@ -206,6 +236,12 @@ class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint):
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room """Notifies that a user has joined or left the room
Request format:
POST /_synapse/replication/membership_change/:room_id/:user_id/:change
{}
""" """
NAME = "membership_change" NAME = "membership_change"

View File

@ -47,7 +47,6 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
""" """
NAME = "send_event" NAME = "send_event"
PATH_ARGS = ("event_id",) PATH_ARGS = ("event_id",)
POST = True
def __init__(self, hs): def __init__(self, hs):
super(ReplicationSendEventRestServlet, self).__init__(hs) super(ReplicationSendEventRestServlet, self).__init__(hs)

View File

@ -44,8 +44,8 @@ class SlavedEventStore(EventFederationWorkerStore,
RoomMemberWorkerStore, RoomMemberWorkerStore,
EventPushActionsWorkerStore, EventPushActionsWorkerStore,
StreamWorkerStore, StreamWorkerStore,
EventsWorkerStore,
StateGroupWorkerStore, StateGroupWorkerStore,
EventsWorkerStore,
SignatureWorkerStore, SignatureWorkerStore,
UserErasureWorkerStore, UserErasureWorkerStore,
BaseSlavedStore): BaseSlavedStore):

View File

@ -39,6 +39,7 @@ from .filtering import FilteringStore
from .group_server import GroupServerStore from .group_server import GroupServerStore
from .keys import KeyStore from .keys import KeyStore
from .media_repository import MediaRepositoryStore from .media_repository import MediaRepositoryStore
from .monthly_active_users import MonthlyActiveUsersStore
from .openid import OpenIdStore from .openid import OpenIdStore
from .presence import PresenceStore, UserPresenceState from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore from .profile import ProfileStore
@ -87,6 +88,7 @@ class DataStore(RoomMemberStore, RoomStore,
UserDirectoryStore, UserDirectoryStore,
GroupServerStore, GroupServerStore,
UserErasureStore, UserErasureStore,
MonthlyActiveUsersStore,
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
@ -94,7 +96,6 @@ class DataStore(RoomMemberStore, RoomStore,
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self.db_conn = db_conn
self._stream_id_gen = StreamIdGenerator( self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")] extra_tables=[("local_invites", "stream_id")]
@ -267,31 +268,6 @@ class DataStore(RoomMemberStore, RoomStore,
return self.runInteraction("count_users", _count_users) return self.runInteraction("count_users", _count_users)
def count_monthly_users(self):
"""Counts the number of users who used this homeserver in the last 30 days
This method should be refactored with count_daily_users - the only
reason not to is waiting on definition of mau
Returns:
Defered[int]
"""
def _count_monthly_users(txn):
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT user_id FROM user_ips
WHERE last_seen > ?
GROUP BY user_id
) u
"""
txn.execute(sql, (thirty_days_ago,))
count, = txn.fetchone()
return count
return self.runInteraction("count_monthly_users", _count_monthly_users)
def count_r30_users(self): def count_r30_users(self):
""" """
Counts the number of 30 day retained users, defined as:- Counts the number of 30 day retained users, defined as:-

View File

@ -35,6 +35,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore): class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache( self.client_ip_last_seen = Cache(
name="client_ip_last_seen", name="client_ip_last_seen",
keylen=4, keylen=4,
@ -74,6 +75,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"before", "shutdown", self._update_client_ips_batch "before", "shutdown", self._update_client_ips_batch
) )
@defer.inlineCallbacks
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
now=None): now=None):
if not now: if not now:
@ -84,7 +86,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key) last_seen = self.client_ip_last_seen.get(key)
except KeyError: except KeyError:
last_seen = None last_seen = None
yield self.populate_monthly_active_users(user_id)
# Rate-limited inserts # Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return return

View File

@ -485,9 +485,14 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
new_forward_extremeties=new_forward_extremeties, new_forward_extremeties=new_forward_extremeties,
) )
persist_event_counter.inc(len(chunk)) persist_event_counter.inc(len(chunk))
synapse.metrics.event_persisted_position.set(
chunk[-1][0].internal_metadata.stream_ordering, if not backfilled:
) # backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
synapse.metrics.event_persisted_position.set(
chunk[-1][0].internal_metadata.stream_ordering,
)
for event, context in chunk: for event, context in chunk:
if context.app_service: if context.app_service:
origin_type = "local" origin_type = "local"

View File

@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the monthly_active_user timestamp
# This means it is not necessary to update the table on every request
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersStore(SQLBaseStore):
def __init__(self, dbconn, hs):
super(MonthlyActiveUsersStore, self).__init__(None, hs)
self._clock = hs.get_clock()
self.hs = hs
self.reserved_users = ()
@defer.inlineCallbacks
def initialise_reserved_users(self, threepids):
# TODO Why can't I do this in init?
store = self.hs.get_datastore()
reserved_user_list = []
# Do not add more reserved users than the total allowable number
for tp in threepids[:self.hs.config.max_mau_value]:
user_id = yield store.get_user_id_by_threepid(
tp["medium"], tp["address"]
)
if user_id:
self.upsert_monthly_active_user(user_id)
reserved_user_list.append(user_id)
else:
logger.warning(
"mau limit reserved threepid %s not found in db" % tp
)
self.reserved_users = tuple(reserved_user_list)
@defer.inlineCallbacks
def reap_monthly_active_users(self):
"""
Cleans out monthly active user table to ensure that no stale
entries exist.
Returns:
Deferred[]
"""
def _reap_users(txn):
thirty_days_ago = (
int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
)
# Purge stale users
# questionmarks is a hack to overcome sqlite not supporting
# tuples in 'WHERE IN %s'
questionmarks = '?' * len(self.reserved_users)
query_args = [thirty_days_ago]
query_args.extend(self.reserved_users)
sql = """
DELETE FROM monthly_active_users
WHERE timestamp < ?
AND user_id NOT IN ({})
""".format(','.join(questionmarks))
txn.execute(sql, query_args)
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
# Note it is not possible to write this query using OFFSET due to
# incompatibilities in how sqlite and postgres support the feature.
# sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present
# While Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can
# support
query_args = [self.hs.config.max_mau_value]
query_args.extend(self.reserved_users)
sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
ORDER BY timestamp DESC
LIMIT ?
)
AND user_id NOT IN ({})
""".format(','.join(questionmarks))
txn.execute(sql, query_args)
yield self.runInteraction("reap_monthly_active_users", _reap_users)
# It seems poor to invalidate the whole cache, Postgres supports
# 'Returning' which would allow me to invalidate only the
# specific users, but sqlite has no way to do this and instead
# I would need to SELECT and the DELETE which without locking
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
self._user_last_seen_monthly_active.invalidate_all()
self.get_monthly_active_count.invalidate_all()
@cached(num_args=0)
def get_monthly_active_count(self):
"""Generates current count of monthly active users
Returns:
Defered[int]: Number of current monthly active users
"""
def _count_users(txn):
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
txn.execute(sql)
count, = txn.fetchone()
return count
return self.runInteraction("count_users", _count_users)
def upsert_monthly_active_user(self, user_id):
"""
Updates or inserts monthly active user member
Arguments:
user_id (str): user to add/update
Deferred[bool]: True if a new entry was created, False if an
existing one was updated.
"""
is_insert = self._simple_upsert(
desc="upsert_monthly_active_user",
table="monthly_active_users",
keyvalues={
"user_id": user_id,
},
values={
"timestamp": int(self._clock.time_msec()),
},
lock=False,
)
if is_insert:
self._user_last_seen_monthly_active.invalidate((user_id,))
self.get_monthly_active_count.invalidate(())
@cached(num_args=1)
def _user_last_seen_monthly_active(self, user_id):
"""
Checks if a given user is part of the monthly active user group
Arguments:
user_id (str): user to add/update
Return:
Deferred[int] : timestamp since last seen, None if never seen
"""
return(self._simple_select_one_onecol(
table="monthly_active_users",
keyvalues={
"user_id": user_id,
},
retcol="timestamp",
allow_none=True,
desc="_user_last_seen_monthly_active",
))
@defer.inlineCallbacks
def populate_monthly_active_users(self, user_id):
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
Args:
user_id(str): the user_id to query
"""
if self.hs.config.limit_usage_by_mau:
last_seen_timestamp = yield self._user_last_seen_monthly_active(user_id)
now = self.hs.get_clock().time_msec()
# We want to reduce to the total number of db writes, and are happy
# to trade accuracy of timestamp in order to lighten load. This means
# We always insert new users (where MAU threshold has not been reached),
# but only update if we have not previously seen the user for
# LAST_SEEN_GRANULARITY ms
if last_seen_timestamp is None:
count = yield self.get_monthly_active_count()
if count < self.hs.config.max_mau_value:
yield self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
yield self.upsert_monthly_active_user(user_id)

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 50 SCHEMA_VERSION = 51
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -0,0 +1,27 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- a table of monthly active users, for use where blocking based on mau limits
CREATE TABLE monthly_active_users (
user_id TEXT NOT NULL,
-- Last time we saw the user. Not guaranteed to be accurate due to rate limiting
-- on updates, Granularity of updates governed by
-- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY
-- Measured in ms since epoch.
timestamp BIGINT NOT NULL
);
CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id);
CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp);

View File

@ -21,15 +21,17 @@ from six.moves import range
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError
from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.events_worker import EventsWorkerStore
from synapse.util.caches import get_cache_factor_for, intern_string from synapse.util.caches import get_cache_factor_for, intern_string
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.stringutils import to_ascii from synapse.util.stringutils import to_ascii
from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,7 +48,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0 return len(self.delta_ids) if self.delta_ids else 0
class StateGroupWorkerStore(SQLBaseStore): # this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers. """The parts of StateGroupStore that can be called from workers.
""" """
@ -61,6 +64,30 @@ class StateGroupWorkerStore(SQLBaseStore):
"*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
) )
@defer.inlineCallbacks
def get_room_version(self, room_id):
"""Get the room_version of a given room
Args:
room_id (str)
Returns:
Deferred[str]
Raises:
NotFoundError if the room is unknown
"""
# for now we do this by looking at the create event. We may want to cache this
# more intelligently in future.
state_ids = yield self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, ""))
if not create_id:
raise NotFoundError("Unknown room")
create_event = yield self.get_event(create_id)
defer.returnValue(create_event.content.get("room_version", "1"))
@cached(max_entries=100000, iterable=True) @cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id): def get_current_state_ids(self, room_id):
"""Get the current state event ids for a room based on the """Get the current state event ids for a room based on the

View File

@ -444,3 +444,28 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual("Guest access token used for regular user", cm.exception.msg) self.assertEqual("Guest access token used for regular user", cm.exception.msg)
self.store.get_user_by_id.assert_called_with(USER_ID) self.store.get_user_by_id.assert_called_with(USER_ID)
@defer.inlineCallbacks
def test_blocking_mau(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
lots_of_users = 100
small_number_of_users = 1
# Ensure no error thrown
yield self.auth.check_auth_blocking()
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users)
)
with self.assertRaises(AuthError):
yield self.auth.check_auth_blocking()
# Ensure does not throw an error
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(small_number_of_users)
)
yield self.auth.check_auth_blocking()

View File

@ -132,14 +132,14 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_exceeded(self): def test_mau_limits_exceeded(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
yield self.auth_handler.get_access_token_for_user_id('user_a') yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().count_monthly_users = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
@ -151,13 +151,13 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_not_exceeded(self): def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users) return_value=defer.succeed(self.small_number_of_users)
) )
# Ensure does not raise exception # Ensure does not raise exception
yield self.auth_handler.get_access_token_for_user_id('user_a') yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().count_monthly_users = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users) return_value=defer.succeed(self.small_number_of_users)
) )
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(

View File

@ -50,6 +50,10 @@ class RegistrationTestCase(unittest.TestCase):
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs) self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler self.handler = self.hs.get_handlers().registration_handler
self.store = self.hs.get_datastore()
self.hs.config.max_mau_value = 50
self.lots_of_users = 100
self.small_number_of_users = 1
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self): def test_user_is_created_and_logged_in_if_doesnt_exist(self):
@ -80,51 +84,43 @@ class RegistrationTestCase(unittest.TestCase):
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cannot_register_when_mau_limits_exceeded(self): def test_mau_limits_when_disabled(self):
local_part = "someone"
display_name = "someone"
requester = create_requester("@as:test")
store = self.hs.get_datastore()
self.hs.config.limit_usage_by_mau = False self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
lots_of_users = 100
small_number_users = 1
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
# Ensure does not throw exception # Ensure does not throw exception
yield self.handler.get_or_create_user(requester, 'a', display_name) yield self.handler.get_or_create_user("requester", 'a', "display_name")
@defer.inlineCallbacks
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
with self.assertRaises(RegistrationError): return_value=defer.succeed(self.small_number_of_users)
yield self.handler.get_or_create_user(requester, 'b', display_name) )
store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
self._macaroon_mock_generator("another_secret")
# Ensure does not throw exception # Ensure does not throw exception
yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil") yield self.handler.get_or_create_user("@user:server", 'c', "User")
self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
@defer.inlineCallbacks
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users)
)
with self.assertRaises(RegistrationError): with self.assertRaises(RegistrationError):
yield self.handler.register(localpart=local_part) yield self.handler.get_or_create_user("requester", 'b', "display_name")
self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
@defer.inlineCallbacks
def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users)
)
with self.assertRaises(RegistrationError): with self.assertRaises(RegistrationError):
yield self.handler.register_saml2(local_part) yield self.handler.register(localpart="local_part")
def _macaroon_mock_generator(self, secret): @defer.inlineCallbacks
""" def test_register_saml2_mau_blocked(self):
Reset macaroon generator in the case where the test creates multiple users self.hs.config.limit_usage_by_mau = True
""" self.store.get_monthly_active_count = Mock(
macaroon_generator = Mock( return_value=defer.succeed(self.lots_of_users)
generate_access_token=Mock(return_value=secret)) )
self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator) with self.assertRaises(RegistrationError):
self.hs.handlers = RegistrationHandlers(self.hs) yield self.handler.register_saml2(localpart="local_part")
self.handler = self.hs.get_handlers().registration_handler

View File

@ -48,7 +48,9 @@ def _expect_edu(destination, edu_type, content, origin="test"):
def _make_edu_json(origin, edu_type, content): def _make_edu_json(origin, edu_type, content):
return json.dumps(_expect_edu("test", edu_type, content, origin=origin)) return json.dumps(
_expect_edu("test", edu_type, content, origin=origin)
).encode('utf8')
class TypingNotificationsTestCase(unittest.TestCase): class TypingNotificationsTestCase(unittest.TestCase):

View File

@ -85,7 +85,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
try: try:
yield self.cache.fetch_or_execute(self.mock_key, cb) yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e: except Exception as e:
self.assertEqual(e.message, "boo") self.assertEqual(e.args[0], "boo")
self.assertIs(LoggingContext.current_context(), test_context) self.assertIs(LoggingContext.current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb) res = yield self.cache.fetch_or_execute(self.mock_key, cb)
@ -111,7 +111,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
try: try:
yield self.cache.fetch_or_execute(self.mock_key, cb) yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e: except Exception as e:
self.assertEqual(e.message, "boo") self.assertEqual(e.args[0], "boo")
self.assertIs(LoggingContext.current_context(), test_context) self.assertIs(LoggingContext.current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb) res = yield self.cache.fetch_or_execute(self.mock_key, cb)

View File

@ -140,7 +140,7 @@ class UserRegisterTestCase(unittest.TestCase):
"admin": True, "admin": True,
"mac": want_mac, "mac": want_mac,
} }
).encode('utf8') )
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) render(request, self.resource, self.clock)
@ -168,7 +168,7 @@ class UserRegisterTestCase(unittest.TestCase):
"admin": True, "admin": True,
"mac": want_mac, "mac": want_mac,
} }
).encode('utf8') )
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) render(request, self.resource, self.clock)
@ -195,7 +195,7 @@ class UserRegisterTestCase(unittest.TestCase):
"admin": True, "admin": True,
"mac": want_mac, "mac": want_mac,
} }
).encode('utf8') )
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) render(request, self.resource, self.clock)
@ -253,7 +253,7 @@ class UserRegisterTestCase(unittest.TestCase):
self.assertEqual('Invalid username', channel.json_body["error"]) self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes # Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"}) body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) render(request, self.resource, self.clock)
@ -289,7 +289,7 @@ class UserRegisterTestCase(unittest.TestCase):
self.assertEqual('Invalid password', channel.json_body["error"]) self.assertEqual('Invalid password', channel.json_body["error"])
# Must not have null bytes # Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"}) body = json.dumps({"nonce": nonce(), "username": "a", "password": u"abcd\u0000"})
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) render(request, self.resource, self.clock)

View File

@ -80,7 +80,7 @@ class ProfileTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"PUT", "PUT",
"/profile/%s/displayname" % (myid), "/profile/%s/displayname" % (myid),
'{"displayname": "Frank Jr."}' b'{"displayname": "Frank Jr."}'
) )
self.assertEquals(200, code) self.assertEquals(200, code)
@ -95,7 +95,7 @@ class ProfileTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"PUT", "/profile/%s/displayname" % ("@4567:test"), "PUT", "/profile/%s/displayname" % ("@4567:test"),
'{"displayname": "Frank Jr."}' b'{"displayname": "Frank Jr."}'
) )
self.assertTrue( self.assertTrue(
@ -122,7 +122,7 @@ class ProfileTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"PUT", "/profile/%s/displayname" % ("@opaque:elsewhere"), "PUT", "/profile/%s/displayname" % ("@opaque:elsewhere"),
'{"displayname":"bob"}' b'{"displayname":"bob"}'
) )
self.assertTrue( self.assertTrue(
@ -151,7 +151,7 @@ class ProfileTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"PUT", "PUT",
"/profile/%s/avatar_url" % (myid), "/profile/%s/avatar_url" % (myid),
'{"avatar_url": "http://my.server/pic.gif"}' b'{"avatar_url": "http://my.server/pic.gif"}'
) )
self.assertEquals(200, code) self.assertEquals(200, code)

View File

@ -105,7 +105,7 @@ class RestTestCase(unittest.TestCase):
"password": "test", "password": "test",
"type": "m.login.password" "type": "m.login.password"
})) }))
self.assertEquals(200, code) self.assertEquals(200, code, msg=response)
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -149,14 +149,14 @@ class RestHelper(object):
def create_room_as(self, room_creator, is_public=True, tok=None): def create_room_as(self, room_creator, is_public=True, tok=None):
temp_id = self.auth_user_id temp_id = self.auth_user_id
self.auth_user_id = room_creator self.auth_user_id = room_creator
path = b"/_matrix/client/r0/createRoom" path = "/_matrix/client/r0/createRoom"
content = {} content = {}
if not is_public: if not is_public:
content["visibility"] = "private" content["visibility"] = "private"
if tok: if tok:
path = path + b"?access_token=%s" % tok.encode('ascii') path = path + "?access_token=%s" % tok
request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8')) request, channel = make_request("POST", path, json.dumps(content).encode('utf8'))
request.render(self.resource) request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel) wait_until_result(self.hs.get_reactor(), channel)
@ -205,7 +205,7 @@ class RestHelper(object):
data = {"membership": membership} data = {"membership": membership}
request, channel = make_request( request, channel = make_request(
b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8') "PUT", path, json.dumps(data).encode('utf8')
) )
request.render(self.resource) request.render(self.resource)

View File

@ -33,7 +33,7 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase): class FilterTestCase(unittest.TestCase):
USER_ID = b"@apple:test" USER_ID = "@apple:test"
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
TO_REGISTER = [filter] TO_REGISTER = [filter]
@ -72,8 +72,8 @@ class FilterTestCase(unittest.TestCase):
def test_add_filter(self): def test_add_filter(self):
request, channel = make_request( request, channel = make_request(
b"POST", "POST",
b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
request.render(self.resource) request.render(self.resource)
@ -87,8 +87,8 @@ class FilterTestCase(unittest.TestCase):
def test_add_filter_for_other_user(self): def test_add_filter_for_other_user(self):
request, channel = make_request( request, channel = make_request(
b"POST", "POST",
b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"), "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
request.render(self.resource) request.render(self.resource)
@ -101,8 +101,8 @@ class FilterTestCase(unittest.TestCase):
_is_mine = self.hs.is_mine _is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False self.hs.is_mine = lambda target_user: False
request, channel = make_request( request, channel = make_request(
b"POST", "POST",
b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
request.render(self.resource) request.render(self.resource)
@ -119,7 +119,7 @@ class FilterTestCase(unittest.TestCase):
self.clock.advance(1) self.clock.advance(1)
filter_id = filter_id.result filter_id = filter_id.result
request, channel = make_request( request, channel = make_request(
b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
) )
request.render(self.resource) request.render(self.resource)
wait_until_result(self.clock, channel) wait_until_result(self.clock, channel)
@ -129,7 +129,7 @@ class FilterTestCase(unittest.TestCase):
def test_get_filter_non_existant(self): def test_get_filter_non_existant(self):
request, channel = make_request( request, channel = make_request(
b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
) )
request.render(self.resource) request.render(self.resource)
wait_until_result(self.clock, channel) wait_until_result(self.clock, channel)
@ -141,7 +141,7 @@ class FilterTestCase(unittest.TestCase):
# in errors.py # in errors.py
def test_get_filter_invalid_id(self): def test_get_filter_invalid_id(self):
request, channel = make_request( request, channel = make_request(
b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
) )
request.render(self.resource) request.render(self.resource)
wait_until_result(self.clock, channel) wait_until_result(self.clock, channel)
@ -151,7 +151,7 @@ class FilterTestCase(unittest.TestCase):
# No ID also returns an invalid_id error # No ID also returns an invalid_id error
def test_get_filter_no_id(self): def test_get_filter_no_id(self):
request, channel = make_request( request, channel = make_request(
b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
) )
request.render(self.resource) request.render(self.resource)
wait_until_result(self.clock, channel) wait_until_result(self.clock, channel)

View File

@ -81,7 +81,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"access_token": token, "access_token": token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_appservice_registration_invalid(self): def test_POST_appservice_registration_invalid(self):
self.appservice = None # no application service exists self.appservice = None # no application service exists
@ -102,7 +102,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals( self.assertEquals(
json.loads(channel.result["body"])["error"], "Invalid password" channel.json_body["error"], "Invalid password"
) )
def test_POST_bad_username(self): def test_POST_bad_username(self):
@ -113,7 +113,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals( self.assertEquals(
json.loads(channel.result["body"])["error"], "Invalid username" channel.json_body["error"], "Invalid username"
) )
def test_POST_user_valid(self): def test_POST_user_valid(self):
@ -140,7 +140,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"device_id": device_id, "device_id": device_id,
} }
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) self.assertDictContainsSubset(det_data, channel.json_body)
self.auth_handler.get_login_tuple_for_user_id( self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None user_id, device_id=device_id, initial_device_display_name=None
) )
@ -158,7 +158,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals( self.assertEquals(
json.loads(channel.result["body"])["error"], channel.json_body["error"],
"Registration has been disabled", "Registration has been disabled",
) )
@ -178,7 +178,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"device_id": "guest_device", "device_id": "guest_device",
} }
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self): def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False self.hs.config.allow_guest_access = False
@ -189,5 +189,5 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals( self.assertEquals(
json.loads(channel.result["body"])["error"], "Guest access is disabled" channel.json_body["error"], "Guest access is disabled"
) )

View File

@ -32,7 +32,7 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase): class FilterTestCase(unittest.TestCase):
USER_ID = b"@apple:test" USER_ID = "@apple:test"
TO_REGISTER = [sync] TO_REGISTER = [sync]
def setUp(self): def setUp(self):
@ -68,7 +68,7 @@ class FilterTestCase(unittest.TestCase):
r.register_servlets(self.hs, self.resource) r.register_servlets(self.hs, self.resource)
def test_sync_argless(self): def test_sync_argless(self):
request, channel = make_request(b"GET", b"/_matrix/client/r0/sync") request, channel = make_request("GET", "/_matrix/client/r0/sync")
request.render(self.resource) request.render(self.resource)
wait_until_result(self.clock, channel) wait_until_result(self.clock, channel)

View File

@ -11,6 +11,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock from twisted.test.proto_helpers import MemoryReactorClock
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth from tests.utils import setup_test_homeserver as _sth
@ -28,7 +29,13 @@ class FakeChannel(object):
def json_body(self): def json_body(self):
if not self.result: if not self.result:
raise Exception("No result yet.") raise Exception("No result yet.")
return json.loads(self.result["body"]) return json.loads(self.result["body"].decode('utf8'))
@property
def code(self):
if not self.result:
raise Exception("No result yet.")
return int(self.result["code"])
def writeHeaders(self, version, code, reason, headers): def writeHeaders(self, version, code, reason, headers):
self.result["version"] = version self.result["version"] = version
@ -79,11 +86,16 @@ def make_request(method, path, content=b""):
Make a web request using the given method and path, feed it the Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath. content, and return the Request and the Channel underneath.
""" """
if not isinstance(method, bytes):
method = method.encode('ascii')
if not isinstance(path, bytes):
path = path.encode('ascii')
# Decorate it to be the full path # Decorate it to be the full path
if not path.startswith(b"/_matrix"): if not path.startswith(b"/_matrix"):
path = b"/_matrix/client/r0/" + path path = b"/_matrix/client/r0/" + path
path = path.replace("//", "/") path = path.replace(b"//", b"/")
if isinstance(content, text_type): if isinstance(content, text_type):
content = content.encode('utf8') content = content.encode('utf8')
@ -191,3 +203,9 @@ def setup_test_homeserver(*args, **kwargs):
clock.threadpool = ThreadPool() clock.threadpool = ThreadPool()
pool.threadpool = ThreadPool() pool.threadpool = ThreadPool()
return d return d
def get_clock():
clock = ThreadedMemoryReactorClock()
hs_clock = Clock(clock)
return (clock, hs_clock)

View File

@ -1,65 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.utils
class InitTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(InitTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
hs.config.max_mau_value = 50
hs.config.limit_usage_by_mau = True
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def test_count_monthly_users(self):
count = yield self.store.count_monthly_users()
self.assertEqual(0, count)
yield self._insert_user_ips("@user:server1")
yield self._insert_user_ips("@user:server2")
count = yield self.store.count_monthly_users()
self.assertEqual(2, count)
@defer.inlineCallbacks
def _insert_user_ips(self, user):
"""
Helper function to populate user_ips without using batch insertion infra
args:
user (str): specify username i.e. @user:server.com
"""
yield self.store._simple_upsert(
table="user_ips",
keyvalues={
"user_id": user,
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"device_id": "device_id",
},
values={
"last_seen": self.clock.time_msec(),
}
)

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mock import Mock
from twisted.internet import defer from twisted.internet import defer
@ -27,9 +28,9 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield tests.utils.setup_test_homeserver() self.hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = hs.get_clock() self.clock = self.hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_insert_new_client_ip(self): def test_insert_new_client_ip(self):
@ -54,3 +55,62 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
}, },
r r
) )
@defer.inlineCallbacks
def test_disabled_monthly_active_user(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
user_id = "@user:server"
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id",
)
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertFalse(active)
@defer.inlineCallbacks
def test_adding_monthly_active_user_when_full(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
lots_of_users = 100
user_id = "@user:server"
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users)
)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id",
)
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertFalse(active)
@defer.inlineCallbacks
def test_adding_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertFalse(active)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id",
)
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertTrue(active)
@defer.inlineCallbacks
def test_updating_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertFalse(active)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id",
)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id",
)
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertTrue(active)

View File

@ -49,7 +49,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
'INSERT INTO event_reference_hashes ' 'INSERT INTO event_reference_hashes '
'(event_id, algorithm, hash) ' '(event_id, algorithm, hash) '
"VALUES (?, 'sha256', ?)" "VALUES (?, 'sha256', ?)"
), (event_id, 'ffff')) ), (event_id, b'ffff'))
for i in range(0, 11): for i in range(0, 11):
yield self.store.runInteraction("insert", insert_event, i) yield self.store.runInteraction("insert", insert_event, i)

View File

@ -0,0 +1,123 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.unittest
import tests.utils
from tests.utils import setup_test_homeserver
FORTY_DAYS = 40 * 24 * 60 * 60
class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs)
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver()
self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_initialise_reserved_users(self):
user1 = "@user1:server"
user1_email = "user1@matrix.org"
user2 = "@user2:server"
user2_email = "user2@matrix.org"
threepids = [
{'medium': 'email', 'address': user1_email},
{'medium': 'email', 'address': user2_email}
]
user_num = len(threepids)
yield self.store.register(
user_id=user1,
token="123",
password_hash=None)
yield self.store.register(
user_id=user2,
token="456",
password_hash=None)
now = int(self.hs.get_clock().time_msec())
yield self.store.user_add_threepid(user1, "email", user1_email, now, now)
yield self.store.user_add_threepid(user2, "email", user2_email, now, now)
yield self.store.initialise_reserved_users(threepids)
active_count = yield self.store.get_monthly_active_count()
# Test total counts
self.assertEquals(active_count, user_num)
# Test user is marked as active
timestamp = yield self.store._user_last_seen_monthly_active(user1)
self.assertTrue(timestamp)
timestamp = yield self.store._user_last_seen_monthly_active(user2)
self.assertTrue(timestamp)
# Test that users are never removed from the db.
self.hs.config.max_mau_value = 0
self.hs.get_clock().advance_time(FORTY_DAYS)
yield self.store.reap_monthly_active_users()
active_count = yield self.store.get_monthly_active_count()
self.assertEquals(active_count, user_num)
@defer.inlineCallbacks
def test_can_insert_and_count_mau(self):
count = yield self.store.get_monthly_active_count()
self.assertEqual(0, count)
yield self.store.upsert_monthly_active_user("@user:server")
count = yield self.store.get_monthly_active_count()
self.assertEqual(1, count)
@defer.inlineCallbacks
def test__user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
result = yield self.store._user_last_seen_monthly_active(user_id1)
self.assertFalse(result == 0)
yield self.store.upsert_monthly_active_user(user_id1)
yield self.store.upsert_monthly_active_user(user_id2)
result = yield self.store._user_last_seen_monthly_active(user_id1)
self.assertTrue(result > 0)
result = yield self.store._user_last_seen_monthly_active(user_id3)
self.assertFalse(result == 0)
@defer.inlineCallbacks
def test_reap_monthly_active_users(self):
self.hs.config.max_mau_value = 5
initial_users = 10
for i in range(initial_users):
yield self.store.upsert_monthly_active_user("@user%d:server" % i)
count = yield self.store.get_monthly_active_count()
self.assertTrue(count, initial_users)
yield self.store.reap_monthly_active_users()
count = yield self.store.get_monthly_active_count()
self.assertEquals(count, initial_users - self.hs.config.max_mau_value)
self.hs.get_clock().advance_time(FORTY_DAYS)
yield self.store.reap_monthly_active_users()
count = yield self.store.get_monthly_active_count()
self.assertEquals(count, 0)

View File

@ -176,7 +176,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
room_id = self.room.to_string() room_id = self.room.to_string()
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
group = group_ids.keys()[0] group = list(group_ids.keys())[0]
# test _get_some_state_from_cache correctly filters out members with types=[] # test _get_some_state_from_cache correctly filters out members with types=[]
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(

View File

@ -1,4 +1,3 @@
import json
import re import re
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
@ -104,9 +103,8 @@ class JsonResourceTests(unittest.TestCase):
request.render(res) request.render(res)
self.assertEqual(channel.result["code"], b'403') self.assertEqual(channel.result["code"], b'403')
reply_body = json.loads(channel.result["body"]) self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
self.assertEqual(reply_body["error"], "Forbidden!!one!") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(reply_body["errcode"], "M_FORBIDDEN")
def test_no_handler(self): def test_no_handler(self):
""" """
@ -126,6 +124,5 @@ class JsonResourceTests(unittest.TestCase):
request.render(res) request.render(res)
self.assertEqual(channel.result["code"], b'400') self.assertEqual(channel.result["code"], b'400')
reply_body = json.loads(channel.result["body"]) self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(reply_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED")

View File

@ -73,6 +73,13 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
config.block_events_without_consent_error = None config.block_events_without_consent_error = None
config.media_storage_providers = [] config.media_storage_providers = []
config.auto_join_rooms = [] config.auto_join_rooms = []
config.limit_usage_by_mau = False
config.max_mau_value = 50
config.mau_limits_reserved_threepids = []
# we need a sane default_room_version, otherwise attempts to create rooms will
# fail.
config.default_room_version = "1"
# disable user directory updates, because they get done in the # disable user directory updates, because they get done in the
# background, which upsets the test runner. # background, which upsets the test runner.
@ -146,8 +153,9 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
# Need to let the HS build an auth handler and then mess with it # Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one # because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg) # beforehand and pass it in to the HS's constructor (chicken / egg)
hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest() hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(
p.encode('utf8')).hexdigest() == h
fed = kargs.get("resource_for_federation", None) fed = kargs.get("resource_for_federation", None)
if fed: if fed:
@ -220,8 +228,8 @@ class MockHttpResource(HttpServer):
mock_content.configure_mock(**config) mock_content.configure_mock(**config)
mock_request.content = mock_content mock_request.content = mock_content
mock_request.method = http_method mock_request.method = http_method.encode('ascii')
mock_request.uri = path mock_request.uri = path.encode('ascii')
mock_request.getClientIP.return_value = "-" mock_request.getClientIP.return_value = "-"