Merge branch 'develop' of github.com:matrix-org/synapse into neilj/server_notices_on_blocking

This commit is contained in:
Neil Johnson 2018-08-15 16:31:40 +01:00
commit fc5d937550
163 changed files with 2945 additions and 2849 deletions

48
.circleci/config.yml Normal file
View File

@ -0,0 +1,48 @@
version: 2
jobs:
sytestpy2:
machine: true
steps:
- checkout
- run: docker pull matrixdotorg/sytest-synapsepy2
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs matrixdotorg/sytest-synapsepy2
- store_artifacts:
path: ~/project/logs
destination: logs
sytestpy2postgres:
machine: true
steps:
- checkout
- run: docker pull matrixdotorg/sytest-synapsepy2
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy2
- store_artifacts:
path: ~/project/logs
destination: logs
sytestpy3:
machine: true
steps:
- checkout
- run: docker pull matrixdotorg/sytest-synapsepy3
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs hawkowl/sytestpy3
- store_artifacts:
path: ~/project/logs
destination: logs
sytestpy3postgres:
machine: true
steps:
- checkout
- run: docker pull matrixdotorg/sytest-synapsepy3
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy3
- store_artifacts:
path: ~/project/logs
destination: logs
workflows:
version: 2
build:
jobs:
- sytestpy2
- sytestpy2postgres
# Currently broken while the Python 3 port is incomplete
# - sytestpy3
# - sytestpy3postgres

View File

@ -3,3 +3,6 @@ Dockerfile
.gitignore .gitignore
demo/etc demo/etc
tox.ini tox.ini
synctl
.git/*
.tox/*

View File

@ -8,6 +8,9 @@ before_script:
- git remote set-branches --add origin develop - git remote set-branches --add origin develop
- git fetch origin develop - git fetch origin develop
services:
- postgresql
matrix: matrix:
fast_finish: true fast_finish: true
include: include:
@ -20,6 +23,9 @@ matrix:
- python: 2.7 - python: 2.7
env: TOX_ENV=py27 env: TOX_ENV=py27
- python: 2.7
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
- python: 3.6 - python: 3.6
env: TOX_ENV=py36 env: TOX_ENV=py36
@ -29,6 +35,10 @@ matrix:
- python: 3.6 - python: 3.6
env: TOX_ENV=check-newsfragment env: TOX_ENV=check-newsfragment
allow_failures:
- python: 2.7
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
install: install:
- pip install tox - pip install tox

View File

@ -36,3 +36,4 @@ recursive-include changelog.d *
prune .github prune .github
prune demo/etc prune demo/etc
prune docker prune docker
prune .circleci

1
changelog.d/1491.feature Normal file
View File

@ -0,0 +1 @@
Add support for the SNI extension to federation TLS connections

1
changelog.d/3423.misc Normal file
View File

@ -0,0 +1 @@
The test suite now can run under PostgreSQL.

1
changelog.d/3653.feature Normal file
View File

@ -0,0 +1 @@
Support more federation endpoints on workers

1
changelog.d/3660.misc Normal file
View File

@ -0,0 +1 @@
Sytests can now be run inside a Docker container.

1
changelog.d/3661.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug on deleting 3pid when using identity servers that don't support unbind API

1
changelog.d/3669.misc Normal file
View File

@ -0,0 +1 @@
Update docker base image from alpine 3.7 to 3.8.

1
changelog.d/3676.bugfix Normal file
View File

@ -0,0 +1 @@
Make the tests pass on Twisted < 18.7.0

1
changelog.d/3677.bugfix Normal file
View File

@ -0,0 +1 @@
Dont ship recaptcha_ajax.js, use it directly from Google

1
changelog.d/3678.misc Normal file
View File

@ -0,0 +1 @@
Rename synapse.util.async to synapse.util.async_helpers to mitigate async becoming a keyword on Python 3.7.

1
changelog.d/3679.misc Normal file
View File

@ -0,0 +1 @@
Synapse's tests are now formatted with the black autoformatter.

1
changelog.d/3681.bugfix Normal file
View File

@ -0,0 +1 @@
Fixes test_reap_monthly_active_users so it passes under postgres

1
changelog.d/3684.misc Normal file
View File

@ -0,0 +1 @@
Implemented a new testing base class to reduce test boilerplate.

1
changelog.d/3687.feature Normal file
View File

@ -0,0 +1 @@
set admin uri via config, to be used in error messages where the user should contact the administrator

1
changelog.d/3690.misc Normal file
View File

@ -0,0 +1 @@
Rename MAU prometheus metrics

1
changelog.d/3692.bugfix Normal file
View File

@ -0,0 +1 @@
Fix missing yield in synapse.storage.monthly_active_users.initialise_reserved_users

View File

@ -1,4 +1,4 @@
FROM docker.io/python:2-alpine3.7 FROM docker.io/python:2-alpine3.8
RUN apk add --no-cache --virtual .nacl_deps \ RUN apk add --no-cache --virtual .nacl_deps \
build-base \ build-base \

View File

@ -173,10 +173,23 @@ endpoints matching the following regular expressions::
^/_matrix/federation/v1/backfill/ ^/_matrix/federation/v1/backfill/
^/_matrix/federation/v1/get_missing_events/ ^/_matrix/federation/v1/get_missing_events/
^/_matrix/federation/v1/publicRooms ^/_matrix/federation/v1/publicRooms
^/_matrix/federation/v1/query/
^/_matrix/federation/v1/make_join/
^/_matrix/federation/v1/make_leave/
^/_matrix/federation/v1/send_join/
^/_matrix/federation/v1/send_leave/
^/_matrix/federation/v1/invite/
^/_matrix/federation/v1/query_auth/
^/_matrix/federation/v1/event_auth/
^/_matrix/federation/v1/exchange_third_party_invite/
^/_matrix/federation/v1/send/
The above endpoints should all be routed to the federation_reader worker by the The above endpoints should all be routed to the federation_reader worker by the
reverse-proxy configuration. reverse-proxy configuration.
The `^/_matrix/federation/v1/send/` endpoint must only be handled by a single
instance.
``synapse.app.federation_sender`` ``synapse.app.federation_sender``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -780,11 +780,14 @@ class Auth(object):
such as monthly active user limiting or global disable flag such as monthly active user limiting or global disable flag
Args: Args:
user_id(str): If present, checks for presence against existing MAU cohort user_id(str|None): If present, checks for presence against existing
MAU cohort
""" """
if self.hs.config.hs_disabled: if self.hs.config.hs_disabled:
raise AuthError( raise AuthError(
403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED 403, self.hs.config.hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEED,
admin_uri=self.hs.config.admin_uri,
) )
if self.hs.config.limit_usage_by_mau is True: if self.hs.config.limit_usage_by_mau is True:
# If the user is already part of the MAU cohort # If the user is already part of the MAU cohort
@ -796,5 +799,7 @@ class Auth(object):
current_mau = yield self.store.get_monthly_active_count() current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value: if current_mau >= self.hs.config.max_mau_value:
raise AuthError( raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED 403, "Monthly Active User Limits AU Limit Exceeded",
admin_uri=self.hs.config.admin_uri,
errcode=Codes.RESOURCE_LIMIT_EXCEED
) )

View File

@ -56,8 +56,7 @@ class Codes(object):
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED" RESOURCE_LIMIT_EXCEED = "M_RESOURCE_LIMIT_EXCEED"
HS_DISABLED = "M_HS_DISABLED"
UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION" UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION"
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
@ -225,11 +224,16 @@ class NotFoundError(SynapseError):
class AuthError(SynapseError): class AuthError(SynapseError):
"""An error raised when there was a problem authorising an event.""" """An error raised when there was a problem authorising an event."""
def __init__(self, code, msg, errcode=Codes.FORBIDDEN, admin_uri=None):
self.admin_uri = admin_uri
super(AuthError, self).__init__(code, msg, errcode=errcode)
def __init__(self, *args, **kwargs): def error_dict(self):
if "errcode" not in kwargs: return cs_error(
kwargs["errcode"] = Codes.FORBIDDEN self.msg,
super(AuthError, self).__init__(*args, **kwargs) self.errcode,
admin_uri=self.admin_uri,
)
class EventSizeError(SynapseError): class EventSizeError(SynapseError):

View File

@ -39,7 +39,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.room import ( from synapse.rest.client.v1.room import (
JoinedRoomMemberListRestServlet, JoinedRoomMemberListRestServlet,
@ -66,7 +66,7 @@ class ClientReaderSlavedStore(
DirectoryStore, DirectoryStore,
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
SlavedRegistrationStore, SlavedRegistrationStore,
TransactionStore, SlavedTransactionStore,
SlavedClientIpStore, SlavedClientIpStore,
BaseSlavedStore, BaseSlavedStore,
): ):
@ -168,11 +168,13 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = ClientReaderServer( ss = ClientReaderServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,

View File

@ -43,7 +43,7 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.room import ( from synapse.rest.client.v1.room import (
JoinRoomAliasServlet, JoinRoomAliasServlet,
@ -63,7 +63,7 @@ logger = logging.getLogger("synapse.app.event_creator")
class EventCreatorSlavedStore( class EventCreatorSlavedStore(
DirectoryStore, DirectoryStore,
TransactionStore, SlavedTransactionStore,
SlavedProfileStore, SlavedProfileStore,
SlavedAccountDataStore, SlavedAccountDataStore,
SlavedPusherStore, SlavedPusherStore,
@ -174,11 +174,13 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = EventCreatorServer( ss = EventCreatorServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,

View File

@ -32,11 +32,16 @@ from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
@ -49,11 +54,16 @@ logger = logging.getLogger("synapse.app.federation_reader")
class FederationReaderSlavedStore( class FederationReaderSlavedStore(
SlavedProfileStore,
SlavedApplicationServiceStore,
SlavedPusherStore,
SlavedPushRuleStore,
SlavedReceiptsStore,
SlavedEventStore, SlavedEventStore,
SlavedKeyStore, SlavedKeyStore,
RoomStore, RoomStore,
DirectoryStore, DirectoryStore,
TransactionStore, SlavedTransactionStore,
BaseSlavedStore, BaseSlavedStore,
): ):
pass pass
@ -143,11 +153,13 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = FederationReaderServer( ss = FederationReaderServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,

View File

@ -36,11 +36,11 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
@ -50,7 +50,7 @@ logger = logging.getLogger("synapse.app.federation_sender")
class FederationSenderSlaveStore( class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore, SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
@ -186,11 +186,13 @@ def start(config_options):
config.send_federation = True config.send_federation = True
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ps = FederationSenderServer( ps = FederationSenderServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,

View File

@ -208,11 +208,13 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = FrontendProxyServer( ss = FrontendProxyServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,

View File

@ -303,8 +303,8 @@ class SynapseHomeServer(HomeServer):
# Gauges to expose monthly active user control metrics # Gauges to expose monthly active user control metrics
current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU") current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit") max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
def setup(config_options): def setup(config_options):
@ -338,6 +338,7 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
@ -346,6 +347,7 @@ def setup(config_options):
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
@ -530,7 +532,7 @@ def run(hs):
if hs.config.limit_usage_by_mau: if hs.config.limit_usage_by_mau:
count = yield hs.get_datastore().get_monthly_active_count() count = yield hs.get_datastore().get_monthly_active_count()
current_mau_gauge.set(float(count)) current_mau_gauge.set(float(count))
max_mau_value_gauge.set(float(hs.config.max_mau_value)) max_mau_gauge.set(float(hs.config.max_mau_value))
hs.get_datastore().initialise_reserved_users( hs.get_datastore().initialise_reserved_users(
hs.config.mau_limits_reserved_threepids hs.config.mau_limits_reserved_threepids

View File

@ -34,7 +34,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.server import HomeServer from synapse.server import HomeServer
@ -52,7 +52,7 @@ class MediaRepositorySlavedStore(
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
SlavedRegistrationStore, SlavedRegistrationStore,
SlavedClientIpStore, SlavedClientIpStore,
TransactionStore, SlavedTransactionStore,
BaseSlavedStore, BaseSlavedStore,
MediaRepositoryStore, MediaRepositoryStore,
): ):
@ -155,11 +155,13 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = MediaRepositoryServer( ss = MediaRepositoryServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,

View File

@ -214,11 +214,13 @@ def start(config_options):
config.update_user_directory = True config.update_user_directory = True
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ps = UserDirectoryServer( ps = UserDirectoryServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,

View File

@ -193,9 +193,8 @@ def setup_logging(config, use_worker_options=False):
def sighup(signum, stack): def sighup(signum, stack):
# it might be better to use a file watcher or something for this. # it might be better to use a file watcher or something for this.
logging.info("Reloading log config from %s due to SIGHUP",
log_config)
load_log_config() load_log_config()
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
load_log_config() load_log_config()

View File

@ -82,6 +82,10 @@ class ServerConfig(Config):
self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled = config.get("hs_disabled", False)
self.hs_disabled_message = config.get("hs_disabled_message", "") self.hs_disabled_message = config.get("hs_disabled_message", "")
# Admin uri to direct users at should their instance become blocked
# due to resource constraints
self.admin_uri = config.get("admin_uri", None)
# FIXME: federation_domain_whitelist needs sytests # FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None self.federation_domain_whitelist = None
federation_domain_whitelist = config.get( federation_domain_whitelist = config.get(

View File

@ -11,19 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from zope.interface import implementer
from OpenSSL import SSL, crypto from OpenSSL import SSL, crypto
from twisted.internet import ssl
from twisted.internet._sslverify import _defaultCurveName from twisted.internet._sslverify import _defaultCurveName
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.ssl import CertificateOptions, ContextFactory
from twisted.python.failure import Failure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ServerContextFactory(ssl.ContextFactory): class ServerContextFactory(ContextFactory):
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming """Factory for PyOpenSSL SSL contexts that are used to handle incoming
connections and to make connections to remote servers.""" connections."""
def __init__(self, config): def __init__(self, config):
self._context = SSL.Context(SSL.SSLv23_METHOD) self._context = SSL.Context(SSL.SSLv23_METHOD)
@ -48,3 +51,78 @@ class ServerContextFactory(ssl.ContextFactory):
def getContext(self): def getContext(self):
return self._context return self._context
def _idnaBytes(text):
"""
Convert some text typed by a human into some ASCII bytes. This is a
copy of twisted.internet._idna._idnaBytes. For documentation, see the
twisted documentation.
"""
try:
import idna
except ImportError:
return text.encode("idna")
else:
return idna.encode(text)
def _tolerateErrors(wrapped):
"""
Wrap up an info_callback for pyOpenSSL so that if something goes wrong
the error is immediately logged and the connection is dropped if possible.
This is a copy of twisted.internet._sslverify._tolerateErrors. For
documentation, see the twisted documentation.
"""
def infoCallback(connection, where, ret):
try:
return wrapped(connection, where, ret)
except: # noqa: E722, taken from the twisted implementation
f = Failure()
logger.exception("Error during info_callback")
connection.get_app_data().failVerification(f)
return infoCallback
@implementer(IOpenSSLClientConnectionCreator)
class ClientTLSOptions(object):
"""
Client creator for TLS without certificate identity verification. This is a
copy of twisted.internet._sslverify.ClientTLSOptions with the identity
verification left out. For documentation, see the twisted documentation.
"""
def __init__(self, hostname, ctx):
self._ctx = ctx
self._hostname = hostname
self._hostnameBytes = _idnaBytes(hostname)
ctx.set_info_callback(
_tolerateErrors(self._identityVerifyingInfoCallback)
)
def clientConnectionForTLS(self, tlsProtocol):
context = self._ctx
connection = SSL.Connection(context, None)
connection.set_app_data(tlsProtocol)
return connection
def _identityVerifyingInfoCallback(self, connection, where, ret):
if where & SSL.SSL_CB_HANDSHAKE_START:
connection.set_tlsext_host_name(self._hostnameBytes)
class ClientTLSOptionsFactory(object):
"""Factory for Twisted ClientTLSOptions that are used to make connections
to remote servers for federation."""
def __init__(self, config):
# We don't use config options yet
pass
def get_options(self, host):
return ClientTLSOptions(
host.decode('utf-8'),
CertificateOptions(verify=False).getContext()
)

View File

@ -30,14 +30,14 @@ KEY_API_V1 = b"/_matrix/key/v1/"
@defer.inlineCallbacks @defer.inlineCallbacks
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1): def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
"""Fetch the keys for a remote server.""" """Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory() factory = SynapseKeyClientFactory()
factory.path = path factory.path = path
factory.host = server_name factory.host = server_name
endpoint = matrix_federation_endpoint( endpoint = matrix_federation_endpoint(
reactor, server_name, ssl_context_factory, timeout=30 reactor, server_name, tls_client_options_factory, timeout=30
) )
for i in range(5): for i in range(5):

View File

@ -512,7 +512,7 @@ class Keyring(object):
continue continue
(response, tls_certificate) = yield fetch_server_key( (response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_server_context_factory, server_name, self.hs.tls_client_options_factory,
path=(b"/_matrix/key/v2/server/%s" % ( path=(b"/_matrix/key/v2/server/%s" % (
urllib.quote(requested_key_id), urllib.quote(requested_key_id),
)).encode("ascii"), )).encode("ascii"),
@ -655,7 +655,7 @@ class Keyring(object):
# Try to fetch the key from the remote server. # Try to fetch the key from the remote server.
(response, tls_certificate) = yield fetch_server_key( (response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_server_context_factory server_name, self.hs.tls_client_options_factory
) )
# Check the response. # Check the response.

View File

@ -39,8 +39,12 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name from synapse.http.endpoint import parse_server_name
from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
ReplicationGetQueryRestServlet,
)
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util import async from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -67,8 +71,8 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler self.handler = hs.get_handlers().federation_handler
self._server_linearizer = async.Linearizer("fed_server") self._server_linearizer = Linearizer("fed_server")
self._transaction_linearizer = async.Linearizer("fed_txn_handler") self._transaction_linearizer = Linearizer("fed_txn_handler")
self.transaction_actions = TransactionActions(self.store) self.transaction_actions = TransactionActions(self.store)
@ -200,7 +204,7 @@ class FederationServer(FederationBase):
event_id, f.getTraceback().rstrip(), event_id, f.getTraceback().rstrip(),
) )
yield async.concurrently_execute( yield concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), process_pdus_for_room, pdus_by_room.keys(),
TRANSACTION_CONCURRENCY_LIMIT, TRANSACTION_CONCURRENCY_LIMIT,
) )
@ -760,6 +764,8 @@ class FederationHandlerRegistry(object):
if edu_type in self.edu_handlers: if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,)) raise KeyError("Already have an EDU handler for %s" % (edu_type,))
logger.info("Registering federation EDU handler for %r", edu_type)
self.edu_handlers[edu_type] = handler self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler): def register_query_handler(self, query_type, handler):
@ -778,6 +784,8 @@ class FederationHandlerRegistry(object):
"Already have a Query handler for %s" % (query_type,) "Already have a Query handler for %s" % (query_type,)
) )
logger.info("Registering federation query handler for %r", query_type)
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
@defer.inlineCallbacks @defer.inlineCallbacks
@ -800,3 +808,49 @@ class FederationHandlerRegistry(object):
raise NotFoundError("No handler for Query type '%s'" % (query_type,)) raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args) return handler(args)
class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
"""A FederationHandlerRegistry for worker processes.
When receiving EDU or queries it will check if an appropriate handler has
been registered on the worker, if there isn't one then it calls off to the
master process.
"""
def __init__(self, hs):
self.config = hs.config
self.http_client = hs.get_simple_http_client()
self.clock = hs.get_clock()
self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
super(ReplicationFederationHandlerRegistry, self).__init__()
def on_edu(self, edu_type, origin, content):
"""Overrides FederationHandlerRegistry
"""
handler = self.edu_handlers.get(edu_type)
if handler:
return super(ReplicationFederationHandlerRegistry, self).on_edu(
edu_type, origin, content,
)
return self._send_edu(
edu_type=edu_type,
origin=origin,
content=content,
)
def on_query(self, query_type, args):
"""Overrides FederationHandlerRegistry
"""
handler = self.query_handlers.get(query_type)
if handler:
return handler(args)
return self._get_query_client(
query_type=query_type,
args=args,
)

View File

@ -828,12 +828,26 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_threepid(self, user_id, medium, address): def delete_threepid(self, user_id, medium, address):
"""Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
Args:
user_id (str)
medium (str)
address (str)
Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn't support the
unbind API.
"""
# 'Canonicalise' email addresses as per above # 'Canonicalise' email addresses as per above
if medium == 'email': if medium == 'email':
address = address.lower() address = address.lower()
identity_handler = self.hs.get_handlers().identity_handler identity_handler = self.hs.get_handlers().identity_handler
yield identity_handler.unbind_threepid( result = yield identity_handler.try_unbind_threepid(
user_id, user_id,
{ {
'medium': medium, 'medium': medium,
@ -841,10 +855,10 @@ class AuthHandler(BaseHandler):
}, },
) )
ret = yield self.store.user_delete_threepid( yield self.store.user_delete_threepid(
user_id, medium, address, user_id, medium, address,
) )
defer.returnValue(ret) defer.returnValue(result)
def _save_session(self, session): def _save_session(self, session):
# TODO: Persistent storage # TODO: Persistent storage

View File

@ -51,7 +51,8 @@ class DeactivateAccountHandler(BaseHandler):
erase_data (bool): whether to GDPR-erase the user's data erase_data (bool): whether to GDPR-erase the user's data
Returns: Returns:
Deferred Deferred[bool]: True if identity server supports removing
threepids, otherwise False.
""" """
# FIXME: Theoretically there is a race here wherein user resets # FIXME: Theoretically there is a race here wherein user resets
# password using threepid. # password using threepid.
@ -60,16 +61,22 @@ class DeactivateAccountHandler(BaseHandler):
# leave the user still active so they can try again. # leave the user still active so they can try again.
# Ideally we would prevent password resets and then do this in the # Ideally we would prevent password resets and then do this in the
# background thread. # background thread.
# This will be set to false if the identity server doesn't support
# unbinding
identity_server_supports_unbinding = True
threepids = yield self.store.user_get_threepids(user_id) threepids = yield self.store.user_get_threepids(user_id)
for threepid in threepids: for threepid in threepids:
try: try:
yield self._identity_handler.unbind_threepid( result = yield self._identity_handler.try_unbind_threepid(
user_id, user_id,
{ {
'medium': threepid['medium'], 'medium': threepid['medium'],
'address': threepid['address'], 'address': threepid['address'],
}, },
) )
identity_server_supports_unbinding &= result
except Exception: except Exception:
# Do we want this to be a fatal error or should we carry on? # Do we want this to be a fatal error or should we carry on?
logger.exception("Failed to remove threepid from ID server") logger.exception("Failed to remove threepid from ID server")
@ -103,6 +110,8 @@ class DeactivateAccountHandler(BaseHandler):
# parts users from rooms (if it isn't already running) # parts users from rooms (if it isn't already running)
self._start_user_parting() self._start_user_parting()
defer.returnValue(identity_server_supports_unbinding)
def _start_user_parting(self): def _start_user_parting(self):
""" """
Start the process that goes through the table of users Start the process that goes through the table of users

View File

@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import FederationDeniedError from synapse.api.errors import FederationDeniedError
from synapse.types import RoomStreamToken, get_domain_from_id from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination

View File

@ -49,10 +49,15 @@ from synapse.crypto.event_signing import (
compute_event_signature, compute_event_signature,
) )
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationFederationSendEventsRestServlet,
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import resolve_events_with_factory from synapse.state import resolve_events_with_factory
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room from synapse.util.distributor import user_joined_room
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -91,6 +96,18 @@ class FederationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_simple_http_client()
self._send_events_to_master = (
ReplicationFederationSendEventsRestServlet.make_client(hs)
)
self._notify_user_membership_change = (
ReplicationUserJoinedLeftRoomRestServlet.make_client(hs)
)
self._clean_room_for_join_client = (
ReplicationCleanRoomRestServlet.make_client(hs)
)
# When joining a room we need to queue any events for that room up # When joining a room we need to queue any events for that room up
self.room_queues = {} self.room_queues = {}
@ -1158,7 +1175,7 @@ class FederationHandler(BaseHandler):
) )
context = yield self.state_handler.compute_event_context(event) context = yield self.state_handler.compute_event_context(event)
yield self._persist_events([(event, context)]) yield self.persist_events_and_notify([(event, context)])
defer.returnValue(event) defer.returnValue(event)
@ -1189,7 +1206,7 @@ class FederationHandler(BaseHandler):
) )
context = yield self.state_handler.compute_event_context(event) context = yield self.state_handler.compute_event_context(event)
yield self._persist_events([(event, context)]) yield self.persist_events_and_notify([(event, context)])
defer.returnValue(event) defer.returnValue(event)
@ -1432,7 +1449,7 @@ class FederationHandler(BaseHandler):
event, context event, context
) )
yield self._persist_events( yield self.persist_events_and_notify(
[(event, context)], [(event, context)],
backfilled=backfilled, backfilled=backfilled,
) )
@ -1470,7 +1487,7 @@ class FederationHandler(BaseHandler):
], consumeErrors=True, ], consumeErrors=True,
)) ))
yield self._persist_events( yield self.persist_events_and_notify(
[ [
(ev_info["event"], context) (ev_info["event"], context)
for ev_info, context in zip(event_infos, contexts) for ev_info, context in zip(event_infos, contexts)
@ -1558,7 +1575,7 @@ class FederationHandler(BaseHandler):
raise raise
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
yield self._persist_events( yield self.persist_events_and_notify(
[ [
(e, events_to_context[e.event_id]) (e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state) for e in itertools.chain(auth_events, state)
@ -1569,7 +1586,7 @@ class FederationHandler(BaseHandler):
event, old_state=state event, old_state=state
) )
yield self._persist_events( yield self.persist_events_and_notify(
[(event, new_event_context)], [(event, new_event_context)],
) )
@ -2297,7 +2314,7 @@ class FederationHandler(BaseHandler):
for revocation. for revocation.
""" """
try: try:
response = yield self.hs.get_simple_http_client().get_json( response = yield self.http_client.get_json(
url, url,
{"public_key": public_key} {"public_key": public_key}
) )
@ -2310,7 +2327,7 @@ class FederationHandler(BaseHandler):
raise AuthError(403, "Third party certificate was invalid") raise AuthError(403, "Third party certificate was invalid")
@defer.inlineCallbacks @defer.inlineCallbacks
def _persist_events(self, event_and_contexts, backfilled=False): def persist_events_and_notify(self, event_and_contexts, backfilled=False):
"""Persists events and tells the notifier/pushers about them, if """Persists events and tells the notifier/pushers about them, if
necessary. necessary.
@ -2322,14 +2339,21 @@ class FederationHandler(BaseHandler):
Returns: Returns:
Deferred Deferred
""" """
max_stream_id = yield self.store.persist_events( if self.config.worker_app:
event_and_contexts, yield self._send_events_to_master(
backfilled=backfilled, store=self.store,
) event_and_contexts=event_and_contexts,
backfilled=backfilled
)
else:
max_stream_id = yield self.store.persist_events(
event_and_contexts,
backfilled=backfilled,
)
if not backfilled: # Never notify for backfilled events if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts: for event, _ in event_and_contexts:
self._notify_persisted_event(event, max_stream_id) self._notify_persisted_event(event, max_stream_id)
def _notify_persisted_event(self, event, max_stream_id): def _notify_persisted_event(self, event, max_stream_id):
"""Checks to see if notifier/pushers should be notified about the """Checks to see if notifier/pushers should be notified about the
@ -2368,9 +2392,25 @@ class FederationHandler(BaseHandler):
) )
def _clean_room_for_join(self, room_id): def _clean_room_for_join(self, room_id):
return self.store.clean_room_for_join(room_id) """Called to clean up any data in DB for a given room, ready for the
server to join the room.
Args:
room_id (str)
"""
if self.config.worker_app:
return self._clean_room_for_join_client(room_id)
else:
return self.store.clean_room_for_join(room_id)
def user_joined_room(self, user, room_id): def user_joined_room(self, user, room_id):
"""Called when a new user has joined the room """Called when a new user has joined the room
""" """
return user_joined_room(self.distributor, user, room_id) if self.config.worker_app:
return self._notify_user_membership_change(
room_id=room_id,
user_id=user.to_string(),
change="joined",
)
else:
return user_joined_room(self.distributor, user, room_id)

View File

@ -137,15 +137,19 @@ class IdentityHandler(BaseHandler):
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def unbind_threepid(self, mxid, threepid): def try_unbind_threepid(self, mxid, threepid):
""" """Removes a binding from an identity server
Removes a binding from an identity server
Args: Args:
mxid (str): Matrix user ID of binding to be removed mxid (str): Matrix user ID of binding to be removed
threepid (dict): Dict with medium & address of binding to be removed threepid (dict): Dict with medium & address of binding to be removed
Raises:
SynapseError: If we failed to contact the identity server
Returns: Returns:
Deferred[bool]: True on success, otherwise False Deferred[bool]: True on success, otherwise False if the identity
server doesn't support unbinding
""" """
logger.debug("unbinding threepid %r from %s", threepid, mxid) logger.debug("unbinding threepid %r from %s", threepid, mxid)
if not self.trusted_id_servers: if not self.trusted_id_servers:
@ -175,11 +179,21 @@ class IdentityHandler(BaseHandler):
content=content, content=content,
destination_is=id_server, destination_is=id_server,
) )
yield self.http_client.post_json_get_json( try:
url, yield self.http_client.post_json_get_json(
content, url,
headers, content,
) headers,
)
except HttpResponseException as e:
if e.code in (400, 404, 501,):
# The remote server probably doesn't support unbinding (yet)
logger.warn("Received %d response while unbinding threepid", e.code)
defer.returnValue(False)
else:
logger.error("Failed to unbind threepid on identity server: %s", e)
raise SynapseError(502, "Failed to contact identity server")
defer.returnValue(True) defer.returnValue(True)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID from synapse.types import StreamToken, UserID
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client

View File

@ -32,7 +32,7 @@ from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.types import RoomAlias, UserID from synapse.types import RoomAlias, UserID
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import run_in_background from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func

View File

@ -22,7 +22,7 @@ from synapse.api.constants import Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.async import ReadWriteLock from synapse.util.async_helpers import ReadWriteLock
from synapse.util.logcontext import run_in_background from synapse.util.logcontext import run_in_background
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client

View File

@ -36,7 +36,7 @@ from synapse.api.errors import SynapseError
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.logcontext import run_in_background from synapse.util.logcontext import run_in_background
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -95,6 +95,7 @@ class PresenceHandler(object):
Args: Args:
hs (synapse.server.HomeServer): hs (synapse.server.HomeServer):
""" """
self.hs = hs
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -230,6 +231,10 @@ class PresenceHandler(object):
earlier than they should when synapse is restarted. This affect of this earlier than they should when synapse is restarted. This affect of this
is some spurious presence changes that will self-correct. is some spurious presence changes that will self-correct.
""" """
# If the DB pool has already terminated, don't try updating
if not self.hs.get_db_pool().running:
return
logger.info( logger.info(
"Performing _on_shutdown. Persisting %d unpersisted changes", "Performing _on_shutdown. Persisting %d unpersisted changes",
len(self.user_to_current_state) len(self.user_to_current_state)

View File

@ -17,7 +17,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler from ._base import BaseHandler

View File

@ -28,7 +28,7 @@ from synapse.api.errors import (
) )
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
from ._base import BaseHandler from ._base import BaseHandler
@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler):
Raises: Raises:
RegistrationError if there was a problem registering. RegistrationError if there was a problem registering.
""" """
yield self._check_mau_limits()
yield self.auth.check_auth_blocking()
password_hash = None password_hash = None
if password: if password:
password_hash = yield self.auth_handler().hash(password) password_hash = yield self.auth_handler().hash(password)
@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler):
400, 400,
"User ID can only contain characters a-z, 0-9, or '=_-./'", "User ID can only contain characters a-z, 0-9, or '=_-./'",
) )
yield self._check_mau_limits() yield self.auth.check_auth_blocking()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler):
""" """
if localpart is None: if localpart is None:
raise SynapseError(400, "Request must include user id") raise SynapseError(400, "Request must include user id")
yield self._check_mau_limits() yield self.auth.check_auth_blocking()
need_register = True need_register = True
try: try:
@ -533,14 +534,3 @@ class RegistrationHandler(BaseHandler):
remote_room_hosts=remote_room_hosts, remote_room_hosts=remote_room_hosts,
action="join", action="join",
) )
@defer.inlineCallbacks
def _check_mau_limits(self):
"""
Do not accept registrations if monthly active user limits exceeded
and limiting is enabled
"""
try:
yield self.auth.check_auth_blocking()
except AuthError as e:
raise RegistrationError(e.code, str(e), e.errcode)

View File

@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules from synapse.api.constants import EventTypes, JoinRules
from synapse.types import ThirdPartyInstanceID from synapse.types import ThirdPartyInstanceID
from synapse.util.async import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache

View File

@ -30,7 +30,7 @@ import synapse.types
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room from synapse.util.distributor import user_joined_room, user_left_room
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -25,7 +25,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.async import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache

View File

@ -42,7 +42,7 @@ from twisted.web.http_headers import Headers
from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http import cancelled_to_request_timed_out_error, redact_uri
from synapse.http.endpoint import SpiderEndpoint from synapse.http.endpoint import SpiderEndpoint
from synapse.util.async import add_timeout_to_deferred from synapse.util.async_helpers import add_timeout_to_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable

View File

@ -26,7 +26,6 @@ from twisted.names.error import DNSNameError, DomainError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SERVER_CACHE = {} SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination. # our record of an individual server which can be tried to reach a destination.
@ -103,15 +102,16 @@ def parse_and_validate_server_name(server_name):
return host, port return host, port
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
timeout=None): timeout=None):
"""Construct an endpoint for the given matrix destination. """Construct an endpoint for the given matrix destination.
Args: Args:
reactor: Twisted reactor. reactor: Twisted reactor.
destination (bytes): The name of the server to connect to. destination (bytes): The name of the server to connect to.
ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory tls_client_options_factory
which generates SSL contexts to use for TLS. (synapse.crypto.context_factory.ClientTLSOptionsFactory):
Factory which generates TLS options for client connections.
timeout (int): connection timeout in seconds timeout (int): connection timeout in seconds
""" """
@ -122,13 +122,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
if timeout is not None: if timeout is not None:
endpoint_kw_args.update(timeout=timeout) endpoint_kw_args.update(timeout=timeout)
if ssl_context_factory is None: if tls_client_options_factory is None:
transport_endpoint = HostnameEndpoint transport_endpoint = HostnameEndpoint
default_port = 8008 default_port = 8008
else: else:
def transport_endpoint(reactor, host, port, timeout): def transport_endpoint(reactor, host, port, timeout):
return wrapClientTLS( return wrapClientTLS(
ssl_context_factory, tls_client_options_factory.get_options(host),
HostnameEndpoint(reactor, host, port, timeout=timeout)) HostnameEndpoint(reactor, host, port, timeout=timeout))
default_port = 8448 default_port = 8448

View File

@ -43,7 +43,7 @@ from synapse.api.errors import (
from synapse.http import cancelled_to_request_timed_out_error from synapse.http import cancelled_to_request_timed_out_error
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util import logcontext from synapse.util import logcontext
from synapse.util.async import add_timeout_to_deferred from synapse.util.async_helpers import add_timeout_to_deferred
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -61,14 +61,14 @@ MAX_SHORT_RETRIES = 3
class MatrixFederationEndpointFactory(object): class MatrixFederationEndpointFactory(object):
def __init__(self, hs): def __init__(self, hs):
self.tls_server_context_factory = hs.tls_server_context_factory self.tls_client_options_factory = hs.tls_client_options_factory
def endpointForURI(self, uri): def endpointForURI(self, uri):
destination = uri.netloc destination = uri.netloc
return matrix_federation_endpoint( return matrix_federation_endpoint(
reactor, destination, timeout=10, reactor, destination, timeout=10,
ssl_context_factory=self.tls_server_context_factory tls_client_options_factory=self.tls_client_options_factory
) )

View File

@ -25,7 +25,7 @@ from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.util.async import ( from synapse.util.async_helpers import (
DeferredTimeoutError, DeferredTimeoutError,
ObservableDeferred, ObservableDeferred,
add_timeout_to_deferred, add_timeout_to_deferred,

View File

@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.event_auth import get_user_power_level from synapse.event_auth import get_user_power_level
from synapse.state import POWER_KEY from synapse.state import POWER_KEY
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached

View File

@ -35,7 +35,7 @@ from synapse.push.presentable_names import (
name_from_member_event, name_from_member_event,
) )
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.replication.http import membership, send_event from synapse.replication.http import federation, membership, send_event
REPLICATION_PREFIX = "/_synapse/replication" REPLICATION_PREFIX = "/_synapse/replication"
@ -27,3 +27,4 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs): def register_servlets(self, hs):
send_event.register_servlets(hs, self) send_event.register_servlets(hs, self)
membership.register_servlets(hs, self) membership.register_servlets(hs, self)
federation.register_servlets(hs, self)

View File

@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""Handles events newly received from federation, including persisting and
notifying.
The API looks like:
POST /_synapse/replication/fed_send_events/:txn_id
{
"events": [{
"event": { .. serialized event .. },
"internal_metadata": { .. serialized internal_metadata .. },
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
}],
"backfilled": false
"""
NAME = "fed_send_events"
PATH_ARGS = ()
def __init__(self, hs):
super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_handler = hs.get_handlers().federation_handler
@staticmethod
@defer.inlineCallbacks
def _serialize_payload(store, event_and_contexts, backfilled):
"""
Args:
store
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether or not the events are the result of
backfilling
"""
event_payloads = []
for event, context in event_and_contexts:
serialized_context = yield context.serialize(event, store)
event_payloads.append({
"event": event.get_pdu_json(),
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
"context": serialized_context,
})
payload = {
"events": event_payloads,
"backfilled": backfilled,
}
defer.returnValue(payload)
@defer.inlineCallbacks
def _handle_request(self, request):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
backfilled = content["backfilled"]
event_payloads = content["events"]
event_and_contexts = []
for event_payload in event_payloads:
event_dict = event_payload["event"]
internal_metadata = event_payload["internal_metadata"]
rejected_reason = event_payload["rejected_reason"]
event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
context = yield EventContext.deserialize(
self.store, event_payload["context"],
)
event_and_contexts.append((event, context))
logger.info(
"Got %d events from federation",
len(event_and_contexts),
)
yield self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled,
)
defer.returnValue((200, {}))
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
"""Handles EDUs newly received from federation, including persisting and
notifying.
Request format:
POST /_synapse/replication/fed_send_edu/:edu_type/:txn_id
{
"origin": ...,
"content: { ... }
}
"""
NAME = "fed_send_edu"
PATH_ARGS = ("edu_type",)
def __init__(self, hs):
super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.registry = hs.get_federation_registry()
@staticmethod
def _serialize_payload(edu_type, origin, content):
return {
"origin": origin,
"content": content,
}
@defer.inlineCallbacks
def _handle_request(self, request, edu_type):
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)
origin = content["origin"]
edu_content = content["content"]
logger.info(
"Got %r edu from $s",
edu_type, origin,
)
result = yield self.registry.on_edu(edu_type, origin, edu_content)
defer.returnValue((200, result))
class ReplicationGetQueryRestServlet(ReplicationEndpoint):
"""Handle responding to queries from federation.
Request format:
POST /_synapse/replication/fed_query/:query_type
{
"args": { ... }
}
"""
NAME = "fed_query"
PATH_ARGS = ("query_type",)
# This is a query, so let's not bother caching
CACHE = False
def __init__(self, hs):
super(ReplicationGetQueryRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.registry = hs.get_federation_registry()
@staticmethod
def _serialize_payload(query_type, args):
"""
Args:
query_type (str)
args (dict): The arguments received for the given query type
"""
return {
"args": args,
}
@defer.inlineCallbacks
def _handle_request(self, request, query_type):
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)
args = content["args"]
logger.info(
"Got %r query",
query_type,
)
result = yield self.registry.on_query(query_type, args)
defer.returnValue((200, result))
class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
Request format:
POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id
{}
"""
NAME = "fed_cleanup_room"
PATH_ARGS = ("room_id",)
def __init__(self, hs):
super(ReplicationCleanRoomRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
@staticmethod
def _serialize_payload(room_id, args):
"""
Args:
room_id (str)
"""
return {}
@defer.inlineCallbacks
def _handle_request(self, request, room_id):
yield self.store.clean_room_for_join(room_id)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server)

View File

@ -13,19 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore from synapse.storage.transactions import TransactionStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
class TransactionStore(BaseSlavedStore): class SlavedTransactionStore(TransactionStore, BaseSlavedStore):
get_destination_retry_timings = TransactionStore.__dict__[ pass
"get_destination_retry_timings"
]
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
_set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
prep_send_transaction = DataStore.prep_send_transaction.__func__
delivered_txn = DataStore.delivered_txn.__func__

View File

@ -17,7 +17,7 @@
to ensure idempotency when performing PUTs using the REST API.""" to ensure idempotency when performing PUTs using the REST API."""
import logging import logging
from synapse.util.async import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -391,10 +391,17 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
yield self._deactivate_account_handler.deactivate_account( result = yield self._deactivate_account_handler.deactivate_account(
target_user_id, erase, target_user_id, erase,
) )
defer.returnValue((200, {})) if result:
id_server_unbind_result = "success"
else:
id_server_unbind_result = "no-support"
defer.returnValue((200, {
"id_server_unbind_result": id_server_unbind_result,
}))
class ShutdownRoomRestServlet(ClientV1RestServlet): class ShutdownRoomRestServlet(ClientV1RestServlet):

View File

@ -209,10 +209,17 @@ class DeactivateAccountRestServlet(RestServlet):
yield self.auth_handler.validate_user_via_ui_auth( yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request), requester, body, self.hs.get_ip_from_request(request),
) )
yield self._deactivate_account_handler.deactivate_account( result = yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, requester.user.to_string(), erase,
) )
defer.returnValue((200, {})) if result:
id_server_unbind_result = "success"
else:
id_server_unbind_result = "no-support"
defer.returnValue((200, {
"id_server_unbind_result": id_server_unbind_result,
}))
class EmailThreepidRequestTokenRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet):
@ -364,7 +371,7 @@ class ThreepidDeleteRestServlet(RestServlet):
user_id = requester.user.to_string() user_id = requester.user.to_string()
try: try:
yield self.auth_handler.delete_threepid( ret = yield self.auth_handler.delete_threepid(
user_id, body['medium'], body['address'] user_id, body['medium'], body['address']
) )
except Exception: except Exception:
@ -374,7 +381,14 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid") logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid") raise SynapseError(500, "Failed to remove threepid")
defer.returnValue((200, {})) if ret:
id_server_unbind_result = "success"
else:
id_server_unbind_result = "no-support"
defer.returnValue((200, {
"id_server_unbind_result": id_server_unbind_result,
}))
class WhoamiRestServlet(RestServlet): class WhoamiRestServlet(RestServlet):

View File

@ -36,7 +36,7 @@ from synapse.api.errors import (
) )
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import is_ascii, random_string from synapse.util.stringutils import is_ascii, random_string

View File

@ -42,7 +42,7 @@ from synapse.http.server import (
) )
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.async import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import is_ascii, random_string from synapse.util.stringutils import is_ascii, random_string

View File

@ -36,6 +36,7 @@ from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import ( from synapse.federation.federation_server import (
FederationHandlerRegistry, FederationHandlerRegistry,
FederationServer, FederationServer,
ReplicationFederationHandlerRegistry,
) )
from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.transaction_queue import TransactionQueue from synapse.federation.transaction_queue import TransactionQueue
@ -423,7 +424,10 @@ class HomeServer(object):
return RoomMemberMasterHandler(self) return RoomMemberMasterHandler(self)
def build_federation_registry(self): def build_federation_registry(self):
return FederationHandlerRegistry() if self.config.worker_app:
return ReplicationFederationHandlerRegistry(self)
else:
return FederationHandlerRegistry()
def build_server_notices_manager(self): def build_server_notices_manager(self):
if self.config.worker_app: if self.config.worker_app:

View File

@ -28,7 +28,7 @@ from synapse import event_auth
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function

View File

@ -4,7 +4,7 @@
<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> <meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="style.css"> <link rel="stylesheet" href="style.css">
<script src="js/jquery-2.1.3.min.js"></script> <script src="js/jquery-2.1.3.min.js"></script>
<script src="js/recaptcha_ajax.js"></script> <script src="https://www.google.com/recaptcha/api/js/recaptcha_ajax.js"></script>
<script src="register_config.js"></script> <script src="register_config.js"></script>
<script src="js/register.js"></script> <script src="js/register.js"></script>
</head> </head>

File diff suppressed because one or more lines are too long

View File

@ -96,6 +96,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now) self._batch_row_update[key] = (user_agent, device_id, now)
def _update_client_ips_batch(self): def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating
if not self.hs.get_db_pool().running:
return
def update(): def update():
to_update = self._batch_row_update to_update = self._batch_row_update
self._batch_row_update = {} self._batch_row_update = {}

View File

@ -38,7 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util.async import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
@ -1435,88 +1435,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
(event.event_id, event.redacts) (event.event_id, event.redacts)
) )
@defer.inlineCallbacks
def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield self._simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
defer.returnValue(set(r["event_id"] for r in rows))
@defer.inlineCallbacks
def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
Deferred[set[str]]: The events we have already seen.
"""
results = set()
def have_seen_events_txn(txn, chunk):
sql = (
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
% (",".join("?" * len(chunk)), )
)
txn.execute(sql, chunk)
for (event_id, ) in txn:
results.add(event_id)
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
[]):
yield self.runInteraction(
"have_seen_events",
have_seen_events_txn,
chunk,
)
defer.returnValue(results)
def get_seen_events_with_rejections(self, event_ids):
"""Given a list of event ids, check if we rejected them.
Args:
event_ids (list[str])
Returns:
Deferred[dict[str, str|None):
Has an entry for each event id we already have seen. Maps to
the rejected reason string if we rejected the event, else maps
to None.
"""
if not event_ids:
return defer.succeed({})
def f(txn):
sql = (
"SELECT e.event_id, reason FROM events as e "
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
"WHERE e.event_id = ?"
)
res = {}
for event_id in event_ids:
txn.execute(sql, (event_id,))
row = txn.fetchone()
if row:
_, rejected = row
res[event_id] = rejected
return res
return self.runInteraction("get_rejection_reasons", f)
@defer.inlineCallbacks @defer.inlineCallbacks
def count_daily_messages(self): def count_daily_messages(self):
""" """

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools
import logging import logging
from collections import namedtuple from collections import namedtuple
@ -442,3 +443,85 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache.prefill((original_ev.event_id,), cache_entry) self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry) defer.returnValue(cache_entry)
@defer.inlineCallbacks
def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield self._simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
defer.returnValue(set(r["event_id"] for r in rows))
@defer.inlineCallbacks
def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
Deferred[set[str]]: The events we have already seen.
"""
results = set()
def have_seen_events_txn(txn, chunk):
sql = (
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
% (",".join("?" * len(chunk)), )
)
txn.execute(sql, chunk)
for (event_id, ) in txn:
results.add(event_id)
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
[]):
yield self.runInteraction(
"have_seen_events",
have_seen_events_txn,
chunk,
)
defer.returnValue(results)
def get_seen_events_with_rejections(self, event_ids):
"""Given a list of event ids, check if we rejected them.
Args:
event_ids (list[str])
Returns:
Deferred[dict[str, str|None):
Has an entry for each event id we already have seen. Maps to
the rejected reason string if we rejected the event, else maps
to None.
"""
if not event_ids:
return defer.succeed({})
def f(txn):
sql = (
"SELECT e.event_id, reason FROM events as e "
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
"WHERE e.event_id = ?"
)
res = {}
for event_id in event_ids:
txn.execute(sql, (event_id,))
row = txn.fetchone()
if row:
_, rejected = row
res[event_id] = rejected
return res
return self.runInteraction("get_rejection_reasons", f)

View File

@ -46,7 +46,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
tp["medium"], tp["address"] tp["medium"], tp["address"]
) )
if user_id: if user_id:
self.upsert_monthly_active_user(user_id) yield self.upsert_monthly_active_user(user_id)
reserved_user_list.append(user_id) reserved_user_list.append(user_id)
else: else:
logger.warning( logger.warning(
@ -64,23 +64,27 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Deferred[] Deferred[]
""" """
def _reap_users(txn): def _reap_users(txn):
# Purge stale users
thirty_days_ago = ( thirty_days_ago = (
int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
) )
# Purge stale users
# questionmarks is a hack to overcome sqlite not supporting
# tuples in 'WHERE IN %s'
questionmarks = '?' * len(self.reserved_users)
query_args = [thirty_days_ago] query_args = [thirty_days_ago]
query_args.extend(self.reserved_users) base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
sql = """ # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
DELETE FROM monthly_active_users # when len(reserved_users) == 0. Works fine on sqlite.
WHERE timestamp < ? if len(self.reserved_users) > 0:
AND user_id NOT IN ({}) # questionmarks is a hack to overcome sqlite not supporting
""".format(','.join(questionmarks)) # tuples in 'WHERE IN %s'
questionmarks = '?' * len(self.reserved_users)
query_args.extend(self.reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format(
','.join(questionmarks)
)
else:
sql = base_sql
txn.execute(sql, query_args) txn.execute(sql, query_args)
@ -93,16 +97,24 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# negative LIMIT values. So there is no way to write it that both can # negative LIMIT values. So there is no way to write it that both can
# support # support
query_args = [self.hs.config.max_mau_value] query_args = [self.hs.config.max_mau_value]
query_args.extend(self.reserved_users)
sql = """ base_sql = """
DELETE FROM monthly_active_users DELETE FROM monthly_active_users
WHERE user_id NOT IN ( WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users SELECT user_id FROM monthly_active_users
ORDER BY timestamp DESC ORDER BY timestamp DESC
LIMIT ? LIMIT ?
) )
AND user_id NOT IN ({}) """
""".format(','.join(questionmarks)) # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
if len(self.reserved_users) > 0:
query_args.extend(self.reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format(
','.join(questionmarks)
)
else:
sql = base_sql
txn.execute(sql, query_args) txn.execute(sql, query_args)
yield self.runInteraction("reap_monthly_active_users", _reap_users) yield self.runInteraction("reap_monthly_active_users", _reap_users)

View File

@ -41,6 +41,22 @@ RatelimitOverride = collections.namedtuple(
class RoomWorkerStore(SQLBaseStore): class RoomWorkerStore(SQLBaseStore):
def get_room(self, room_id):
"""Retrieve a room.
Args:
room_id (str): The ID of the room to retrieve.
Returns:
A namedtuple containing the room information, or an empty list.
"""
return self._simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
desc="get_room",
allow_none=True,
)
def get_public_room_ids(self): def get_public_room_ids(self):
return self._simple_select_onecol( return self._simple_select_onecol(
table="rooms", table="rooms",
@ -215,22 +231,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e) logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.") raise StoreError(500, "Problem creating room.")
def get_room(self, room_id):
"""Retrieve a room.
Args:
room_id (str): The ID of the room to retrieve.
Returns:
A namedtuple containing the room information, or an empty list.
"""
return self._simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
desc="get_room",
allow_none=True,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_room_is_public(self, room_id, is_public): def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id): def set_room_is_public_txn(txn, next_id):

View File

@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util.async import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii from synapse.util.stringutils import to_ascii

View File

@ -188,62 +188,30 @@ class Linearizer(object):
# things blocked from executing. # things blocked from executing.
self.key_to_defer = {} self.key_to_defer = {}
@defer.inlineCallbacks
def queue(self, key): def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
# (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
# propagated inside inlineCallbacks until Twisted 18.7)
entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
# If the number of things executing is greater than the maximum # If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items # then add a deferred to the list of blocked items
# When on of the things currently executing finishes it will callback # When one of the things currently executing finishes it will callback
# this item so that it can continue executing. # this item so that it can continue executing.
if entry[0] >= self.max_count: if entry[0] >= self.max_count:
new_defer = defer.Deferred() res = self._await_lock(key)
entry[1][new_defer] = 1
logger.info(
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
)
try:
yield make_deferred_yieldable(new_defer)
except Exception as e:
if isinstance(e, CancelledError):
logger.info(
"Cancelling wait for linearizer lock %r for key %r",
self.name, key,
)
else:
logger.warn(
"Unexpected exception waiting for linearizer lock %r for key %r",
self.name, key,
)
# we just have to take ourselves back out of the queue.
del entry[1][new_defer]
raise
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
entry[0] += 1
# if the code holding the lock completes synchronously, then it
# will recursively run the next claimant on the list. That can
# relatively rapidly lead to stack exhaustion. This is essentially
# the same problem as http://twistedmatrix.com/trac/ticket/9304.
#
# In order to break the cycle, we add a cheeky sleep(0) here to
# ensure that we fall back to the reactor between each iteration.
#
# (This needs to happen while we hold the lock, and the context manager's exit
# code must be synchronous, so this is the only sensible place.)
yield self._clock.sleep(0)
else: else:
logger.info( logger.info(
"Acquired uncontended linearizer lock %r for key %r", self.name, key, "Acquired uncontended linearizer lock %r for key %r", self.name, key,
) )
entry[0] += 1 entry[0] += 1
res = defer.succeed(None)
# once we successfully get the lock, we need to return a context manager which
# will release the lock.
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager(_):
try: try:
yield yield
finally: finally:
@ -264,7 +232,64 @@ class Linearizer(object):
# map. # map.
del self.key_to_defer[key] del self.key_to_defer[key]
defer.returnValue(_ctx_manager()) res.addCallback(_ctx_manager)
return res
def _await_lock(self, key):
"""Helper for queue: adds a deferred to the queue
Assumes that we've already checked that we've reached the limit of the number
of lock-holders we allow. Creates a new deferred which is added to the list, and
adds some management around cancellations.
Returns the deferred, which will callback once we have secured the lock.
"""
entry = self.key_to_defer[key]
logger.info(
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
)
new_defer = make_deferred_yieldable(defer.Deferred())
entry[1][new_defer] = 1
def cb(_r):
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
entry[0] += 1
# if the code holding the lock completes synchronously, then it
# will recursively run the next claimant on the list. That can
# relatively rapidly lead to stack exhaustion. This is essentially
# the same problem as http://twistedmatrix.com/trac/ticket/9304.
#
# In order to break the cycle, we add a cheeky sleep(0) here to
# ensure that we fall back to the reactor between each iteration.
#
# (This needs to happen while we hold the lock, and the context manager's exit
# code must be synchronous, so this is the only sensible place.)
return self._clock.sleep(0)
def eb(e):
logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError):
logger.info(
"Cancelling wait for linearizer lock %r for key %r",
self.name, key,
)
else:
logger.warn(
"Unexpected exception waiting for linearizer lock %r for key %r",
self.name, key,
)
# we just have to take ourselves back out of the queue.
del entry[1][new_defer]
return e
new_defer.addCallbacks(cb, eb)
return new_defer
class ReadWriteLock(object): class ReadWriteLock(object):

View File

@ -25,7 +25,7 @@ from six import itervalues, string_types
from twisted.internet import defer from twisted.internet import defer
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.async import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry

View File

@ -16,7 +16,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.util.async import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.async import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
class SnapshotCache(object): class SnapshotCache(object):

View File

@ -526,7 +526,7 @@ _to_ignore = [
"synapse.util.logcontext", "synapse.util.logcontext",
"synapse.http.server", "synapse.http.server",
"synapse.storage._base", "synapse.storage._base",
"synapse.util.async", "synapse.util.async_helpers",
] ]

View File

@ -15,4 +15,7 @@
from twisted.trial import util from twisted.trial import util
from tests import utils
util.DEFAULT_TIMEOUT_DURATION = 10 util.DEFAULT_TIMEOUT_DURATION = 10
utils.setupdb()

View File

@ -34,13 +34,12 @@ class TestHandlers(object):
class AuthTestCase(unittest.TestCase): class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.state_handler = Mock() self.state_handler = Mock()
self.store = Mock() self.store = Mock()
self.hs = yield setup_test_homeserver(handlers=None) self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
self.hs.get_datastore = Mock(return_value=self.store) self.hs.get_datastore = Mock(return_value=self.store)
self.hs.handlers = TestHandlers(self.hs) self.hs.handlers = TestHandlers(self.hs)
self.auth = Auth(self.hs) self.auth = Auth(self.hs)
@ -53,11 +52,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self): def test_get_user_by_req_user_valid_token(self):
user_info = { user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
"name": self.test_user,
"token_id": "ditto",
"device_id": "device",
}
self.store.get_user_by_access_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
@ -76,10 +71,7 @@ class AuthTestCase(unittest.TestCase):
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
def test_get_user_by_req_user_missing_token(self): def test_get_user_by_req_user_missing_token(self):
user_info = { user_info = {"name": self.test_user, "token_id": "ditto"}
"name": self.test_user,
"token_id": "ditto",
}
self.store.get_user_by_access_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
@ -90,8 +82,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self): def test_get_user_by_req_appservice_valid_token(self):
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
ip_range_whitelist=None,
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
@ -106,8 +97,11 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_good_ip(self): def test_get_user_by_req_appservice_valid_token_good_ip(self):
from netaddr import IPSet from netaddr import IPSet
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar",
url="a_url",
sender=self.test_user,
ip_range_whitelist=IPSet(["192.168/16"]), ip_range_whitelist=IPSet(["192.168/16"]),
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
@ -122,8 +116,11 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_valid_token_bad_ip(self): def test_get_user_by_req_appservice_valid_token_bad_ip(self):
from netaddr import IPSet from netaddr import IPSet
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar",
url="a_url",
sender=self.test_user,
ip_range_whitelist=IPSet(["192.168/16"]), ip_range_whitelist=IPSet(["192.168/16"]),
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
@ -160,8 +157,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_valid_token_valid_user_id(self): def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = b"@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
ip_range_whitelist=None,
) )
app_service.is_interested_in_user = Mock(return_value=True) app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
@ -174,15 +170,13 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals( self.assertEquals(
requester.user.to_string(), requester.user.to_string(), masquerading_user_id.decode('utf8')
masquerading_user_id.decode('utf8')
) )
def test_get_user_by_req_appservice_valid_token_bad_user_id(self): def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = b"@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
ip_range_whitelist=None,
) )
app_service.is_interested_in_user = Mock(return_value=False) app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
@ -201,17 +195,15 @@ class AuthTestCase(unittest.TestCase):
# TODO(danielwh): Remove this mock when we remove the # TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback. # get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value={ return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
"name": "@baldrick:matrix.org",
"device_id": "device",
}
) )
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
@ -225,15 +217,14 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self): def test_get_guest_user_from_macaroon(self):
self.store.get_user_by_id = Mock(return_value={ self.store.get_user_by_id = Mock(return_value={"is_guest": True})
"is_guest": True,
})
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
@ -257,7 +248,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,)) macaroon.add_first_party_caveat("user_id = %s" % (user,))
@ -277,7 +269,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
@ -298,7 +291,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key + "wrong") key=self.hs.config.macaroon_secret_key + "wrong",
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,)) macaroon.add_first_party_caveat("user_id = %s" % (user,))
@ -320,7 +314,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,)) macaroon.add_first_party_caveat("user_id = %s" % (user,))
@ -347,7 +342,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,)) macaroon.add_first_party_caveat("user_id = %s" % (user,))
@ -380,7 +376,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
@ -401,9 +398,7 @@ class AuthTestCase(unittest.TestCase):
token = yield self.hs.handlers.auth_handler.issue_access_token( token = yield self.hs.handlers.auth_handler.issue_access_token(
USER_ID, "DEVICE" USER_ID, "DEVICE"
) )
self.store.add_access_token_to_user.assert_called_with( self.store.add_access_token_to_user.assert_called_with(USER_ID, token, "DEVICE")
USER_ID, token, "DEVICE"
)
def get_user(tok): def get_user(tok):
if token != tok: if token != tok:
@ -414,10 +409,9 @@ class AuthTestCase(unittest.TestCase):
"token_id": 1234, "token_id": 1234,
"device_id": "DEVICE", "device_id": "DEVICE",
} }
self.store.get_user_by_access_token = get_user self.store.get_user_by_access_token = get_user
self.store.get_user_by_id = Mock(return_value={ self.store.get_user_by_id = Mock(return_value={"is_guest": False})
"is_guest": False,
})
# check the token works # check the token works
request = Mock(args={}) request = Mock(args={})
@ -461,8 +455,11 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(lots_of_users) return_value=defer.succeed(lots_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(AuthError) as e:
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
self.assertEquals(e.exception.code, 403)
# Ensure does not throw an error # Ensure does not throw an error
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
@ -476,5 +473,6 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.hs_disabled_message = "Reason for being disabled" self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(AuthError) as e: with self.assertRaises(AuthError) as e:
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
self.assertEquals(e.exception.code, 403) self.assertEquals(e.exception.code, 403)

View File

@ -38,7 +38,6 @@ def MockEvent(**kwargs):
class FilteringTestCase(unittest.TestCase): class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_federation_resource = MockHttpResource() self.mock_federation_resource = MockHttpResource()
@ -47,6 +46,7 @@ class FilteringTestCase(unittest.TestCase):
self.mock_http_client.put_json = DeferredMockCallable() self.mock_http_client.put_json = DeferredMockCallable()
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
self.addCleanup,
handlers=None, handlers=None,
http_client=self.mock_http_client, http_client=self.mock_http_client,
keyring=Mock(), keyring=Mock(),
@ -64,7 +64,7 @@ class FilteringTestCase(unittest.TestCase):
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}}, {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
{"event_format": "other"}, {"event_format": "other"},
{"room": {"not_rooms": ["#foo:pik-test"]}}, {"room": {"not_rooms": ["#foo:pik-test"]}},
{"presence": {"senders": ["@bar;pik.test.com"]}} {"presence": {"senders": ["@bar;pik.test.com"]}},
] ]
for filter in invalid_filters: for filter in invalid_filters:
with self.assertRaises(SynapseError) as check_filter_error: with self.assertRaises(SynapseError) as check_filter_error:
@ -81,34 +81,34 @@ class FilteringTestCase(unittest.TestCase):
"include_leave": False, "include_leave": False,
"rooms": ["!dee:pik-test"], "rooms": ["!dee:pik-test"],
"not_rooms": ["!gee:pik-test"], "not_rooms": ["!gee:pik-test"],
"account_data": {"limit": 0, "types": ["*"]} "account_data": {"limit": 0, "types": ["*"]},
} }
}, },
{ {
"room": { "room": {
"state": { "state": {
"types": ["m.room.*"], "types": ["m.room.*"],
"not_rooms": ["!726s6s6q:example.com"] "not_rooms": ["!726s6s6q:example.com"],
}, },
"timeline": { "timeline": {
"limit": 10, "limit": 10,
"types": ["m.room.message"], "types": ["m.room.message"],
"not_rooms": ["!726s6s6q:example.com"], "not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"] "not_senders": ["@spam:example.com"],
}, },
"ephemeral": { "ephemeral": {
"types": ["m.receipt", "m.typing"], "types": ["m.receipt", "m.typing"],
"not_rooms": ["!726s6s6q:example.com"], "not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"] "not_senders": ["@spam:example.com"],
} },
}, },
"presence": { "presence": {
"types": ["m.presence"], "types": ["m.presence"],
"not_senders": ["@alice:example.com"] "not_senders": ["@alice:example.com"],
}, },
"event_format": "client", "event_format": "client",
"event_fields": ["type", "content", "sender"] "event_fields": ["type", "content", "sender"],
} },
] ]
for filter in valid_filters: for filter in valid_filters:
try: try:
@ -121,229 +121,131 @@ class FilteringTestCase(unittest.TestCase):
pass pass
def test_definition_types_works_with_literals(self): def test_definition_types_works_with_literals(self):
definition = { definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
"types": ["m.room.message", "org.matrix.foo.bar"] event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertTrue( self.assertTrue(Filter(definition).check(event))
Filter(definition).check(event)
)
def test_definition_types_works_with_wildcards(self): def test_definition_types_works_with_wildcards(self):
definition = { definition = {"types": ["m.*", "org.matrix.foo.bar"]}
"types": ["m.*", "org.matrix.foo.bar"] event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
} self.assertTrue(Filter(definition).check(event))
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertTrue(
Filter(definition).check(event)
)
def test_definition_types_works_with_unknowns(self): def test_definition_types_works_with_unknowns(self):
definition = { definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
"types": ["m.room.message", "org.matrix.foo.bar"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="now.for.something.completely.different", type="now.for.something.completely.different",
room_id="!foo:bar" room_id="!foo:bar",
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_types_works_with_literals(self): def test_definition_not_types_works_with_literals(self):
definition = { definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
"not_types": ["m.room.message", "org.matrix.foo.bar"] event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
} self.assertFalse(Filter(definition).check(event))
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
)
def test_definition_not_types_works_with_wildcards(self): def test_definition_not_types_works_with_wildcards(self):
definition = { definition = {"not_types": ["m.room.message", "org.matrix.*"]}
"not_types": ["m.room.message", "org.matrix.*"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
type="org.matrix.custom.event",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_types_works_with_unknowns(self): def test_definition_not_types_works_with_unknowns(self):
definition = { definition = {"not_types": ["m.*", "org.*"]}
"not_types": ["m.*", "org.*"] event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
} self.assertTrue(Filter(definition).check(event))
event = MockEvent(
sender="@foo:bar",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
Filter(definition).check(event)
)
def test_definition_not_types_takes_priority_over_types(self): def test_definition_not_types_takes_priority_over_types(self):
definition = { definition = {
"not_types": ["m.*", "org.*"], "not_types": ["m.*", "org.*"],
"types": ["m.room.message", "m.room.topic"] "types": ["m.room.message", "m.room.topic"],
} }
event = MockEvent( event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
sender="@foo:bar", self.assertFalse(Filter(definition).check(event))
type="m.room.topic",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
)
def test_definition_senders_works_with_literals(self): def test_definition_senders_works_with_literals(self):
definition = { definition = {"senders": ["@flibble:wibble"]}
"senders": ["@flibble:wibble"]
}
event = MockEvent( event = MockEvent(
sender="@flibble:wibble", sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
Filter(definition).check(event)
) )
self.assertTrue(Filter(definition).check(event))
def test_definition_senders_works_with_unknowns(self): def test_definition_senders_works_with_unknowns(self):
definition = { definition = {"senders": ["@flibble:wibble"]}
"senders": ["@flibble:wibble"]
}
event = MockEvent( event = MockEvent(
sender="@challenger:appears", sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_senders_works_with_literals(self): def test_definition_not_senders_works_with_literals(self):
definition = { definition = {"not_senders": ["@flibble:wibble"]}
"not_senders": ["@flibble:wibble"]
}
event = MockEvent( event = MockEvent(
sender="@flibble:wibble", sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_senders_works_with_unknowns(self): def test_definition_not_senders_works_with_unknowns(self):
definition = { definition = {"not_senders": ["@flibble:wibble"]}
"not_senders": ["@flibble:wibble"]
}
event = MockEvent( event = MockEvent(
sender="@challenger:appears", sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
Filter(definition).check(event)
) )
self.assertTrue(Filter(definition).check(event))
def test_definition_not_senders_takes_priority_over_senders(self): def test_definition_not_senders_takes_priority_over_senders(self):
definition = { definition = {
"not_senders": ["@misspiggy:muppets"], "not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets", "@misspiggy:muppets"] "senders": ["@kermit:muppets", "@misspiggy:muppets"],
} }
event = MockEvent( event = MockEvent(
sender="@misspiggy:muppets", sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar"
type="m.room.topic",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_rooms_works_with_literals(self): def test_definition_rooms_works_with_literals(self):
definition = { definition = {"rooms": ["!secretbase:unknown"]}
"rooms": ["!secretbase:unknown"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
type="m.room.message",
room_id="!secretbase:unknown"
)
self.assertTrue(
Filter(definition).check(event)
) )
self.assertTrue(Filter(definition).check(event))
def test_definition_rooms_works_with_unknowns(self): def test_definition_rooms_works_with_unknowns(self):
definition = { definition = {"rooms": ["!secretbase:unknown"]}
"rooms": ["!secretbase:unknown"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="m.room.message", type="m.room.message",
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown",
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_rooms_works_with_literals(self): def test_definition_not_rooms_works_with_literals(self):
definition = { definition = {"not_rooms": ["!anothersecretbase:unknown"]}
"not_rooms": ["!anothersecretbase:unknown"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="m.room.message", type="m.room.message",
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown",
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_rooms_works_with_unknowns(self): def test_definition_not_rooms_works_with_unknowns(self):
definition = { definition = {"not_rooms": ["!secretbase:unknown"]}
"not_rooms": ["!secretbase:unknown"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="m.room.message", type="m.room.message",
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown",
)
self.assertTrue(
Filter(definition).check(event)
) )
self.assertTrue(Filter(definition).check(event))
def test_definition_not_rooms_takes_priority_over_rooms(self): def test_definition_not_rooms_takes_priority_over_rooms(self):
definition = { definition = {
"not_rooms": ["!secretbase:unknown"], "not_rooms": ["!secretbase:unknown"],
"rooms": ["!secretbase:unknown"] "rooms": ["!secretbase:unknown"],
} }
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
type="m.room.message",
room_id="!secretbase:unknown"
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_combined_event(self): def test_definition_combined_event(self):
definition = { definition = {
@ -352,16 +254,14 @@ class FilteringTestCase(unittest.TestCase):
"rooms": ["!stage:unknown"], "rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"], "not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"], "types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"] "not_types": ["muppets.misspiggy.*"],
} }
event = MockEvent( event = MockEvent(
sender="@kermit:muppets", # yup sender="@kermit:muppets", # yup
type="m.room.message", # yup type="m.room.message", # yup
room_id="!stage:unknown" # yup room_id="!stage:unknown", # yup
)
self.assertTrue(
Filter(definition).check(event)
) )
self.assertTrue(Filter(definition).check(event))
def test_definition_combined_event_bad_sender(self): def test_definition_combined_event_bad_sender(self):
definition = { definition = {
@ -370,16 +270,14 @@ class FilteringTestCase(unittest.TestCase):
"rooms": ["!stage:unknown"], "rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"], "not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"], "types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"] "not_types": ["muppets.misspiggy.*"],
} }
event = MockEvent( event = MockEvent(
sender="@misspiggy:muppets", # nope sender="@misspiggy:muppets", # nope
type="m.room.message", # yup type="m.room.message", # yup
room_id="!stage:unknown" # yup room_id="!stage:unknown", # yup
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_combined_event_bad_room(self): def test_definition_combined_event_bad_room(self):
definition = { definition = {
@ -388,16 +286,14 @@ class FilteringTestCase(unittest.TestCase):
"rooms": ["!stage:unknown"], "rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"], "not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"], "types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"] "not_types": ["muppets.misspiggy.*"],
} }
event = MockEvent( event = MockEvent(
sender="@kermit:muppets", # yup sender="@kermit:muppets", # yup
type="m.room.message", # yup type="m.room.message", # yup
room_id="!piggyshouse:muppets" # nope room_id="!piggyshouse:muppets", # nope
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_combined_event_bad_type(self): def test_definition_combined_event_bad_type(self):
definition = { definition = {
@ -406,37 +302,26 @@ class FilteringTestCase(unittest.TestCase):
"rooms": ["!stage:unknown"], "rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"], "not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"], "types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"] "not_types": ["muppets.misspiggy.*"],
} }
event = MockEvent( event = MockEvent(
sender="@kermit:muppets", # yup sender="@kermit:muppets", # yup
type="muppets.misspiggy.kisses", # nope type="muppets.misspiggy.kisses", # nope
room_id="!stage:unknown" # yup room_id="!stage:unknown", # yup
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_presence_match(self): def test_filter_presence_match(self):
user_filter_json = { user_filter_json = {"presence": {"types": ["m.*"]}}
"presence": {
"types": ["m.*"]
}
}
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, user_filter=user_filter_json
user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
type="m.profile",
) )
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, filter_id=filter_id
filter_id=filter_id,
) )
results = user_filter.filter_presence(events=events) results = user_filter.filter_presence(events=events)
@ -444,15 +329,10 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_presence_no_match(self): def test_filter_presence_no_match(self):
user_filter_json = { user_filter_json = {"presence": {"types": ["m.*"]}}
"presence": {
"types": ["m.*"]
}
}
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_localpart=user_localpart + "2", user_filter=user_filter_json
user_filter=user_filter_json,
) )
event = MockEvent( event = MockEvent(
event_id="$asdasd:localhost", event_id="$asdasd:localhost",
@ -462,8 +342,7 @@ class FilteringTestCase(unittest.TestCase):
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart + "2", user_localpart=user_localpart + "2", filter_id=filter_id
filter_id=filter_id,
) )
results = user_filter.filter_presence(events=events) results = user_filter.filter_presence(events=events)
@ -471,27 +350,15 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_room_state_match(self): def test_filter_room_state_match(self):
user_filter_json = { user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, user_filter=user_filter_json
user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
type="m.room.topic",
room_id="!foo:bar"
) )
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, filter_id=filter_id
filter_id=filter_id,
) )
results = user_filter.filter_room_state(events=events) results = user_filter.filter_room_state(events=events)
@ -499,27 +366,17 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_room_state_no_match(self): def test_filter_room_state_no_match(self):
user_filter_json = { user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, user_filter=user_filter_json
user_filter=user_filter_json,
) )
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
type="org.matrix.custom.event",
room_id="!foo:bar"
) )
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, filter_id=filter_id
filter_id=filter_id,
) )
results = user_filter.filter_room_state(events) results = user_filter.filter_room_state(events)
@ -543,45 +400,32 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_add_filter(self): def test_add_filter(self):
user_filter_json = { user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, user_filter=user_filter_json
user_filter=user_filter_json,
) )
self.assertEquals(filter_id, 0) self.assertEquals(filter_id, 0)
self.assertEquals(user_filter_json, ( self.assertEquals(
yield self.datastore.get_user_filter( user_filter_json,
user_localpart=user_localpart, (
filter_id=0, yield self.datastore.get_user_filter(
) user_localpart=user_localpart, filter_id=0
)) )
),
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_filter(self): def test_get_filter(self):
user_filter_json = { user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, user_filter=user_filter_json
user_filter=user_filter_json,
) )
filter = yield self.filtering.get_user_filter( filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, filter_id=filter_id
filter_id=filter_id,
) )
self.assertEquals(filter.get_filter_json(), user_filter_json) self.assertEquals(filter.get_filter_json(), user_filter_json)

View File

@ -4,17 +4,16 @@ from tests import unittest
class TestRatelimiter(unittest.TestCase): class TestRatelimiter(unittest.TestCase):
def test_allowed(self): def test_allowed(self):
limiter = Ratelimiter() limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message( allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1, user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1
) )
self.assertTrue(allowed) self.assertTrue(allowed)
self.assertEquals(10., time_allowed) self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message( allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1, user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1
) )
self.assertFalse(allowed) self.assertFalse(allowed)
self.assertEquals(10., time_allowed) self.assertEquals(10., time_allowed)
@ -28,7 +27,7 @@ class TestRatelimiter(unittest.TestCase):
def test_pruning(self): def test_pruning(self):
limiter = Ratelimiter() limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message( allowed, time_allowed = limiter.send_message(
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1, user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1
) )
self.assertIn("test_id_1", limiter.message_counts) self.assertIn("test_id_1", limiter.message_counts)

View File

@ -24,14 +24,10 @@ from tests import unittest
def _regex(regex, exclusive=True): def _regex(regex, exclusive=True):
return { return {"regex": re.compile(regex), "exclusive": exclusive}
"regex": re.compile(regex),
"exclusive": exclusive
}
class ApplicationServiceTestCase(unittest.TestCase): class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.service = ApplicationService( self.service = ApplicationService(
id="unique_identifier", id="unique_identifier",
@ -41,8 +37,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
namespaces={ namespaces={
ApplicationService.NS_USERS: [], ApplicationService.NS_USERS: [],
ApplicationService.NS_ROOMS: [], ApplicationService.NS_ROOMS: [],
ApplicationService.NS_ALIASES: [] ApplicationService.NS_ALIASES: [],
} },
) )
self.event = Mock( self.event = Mock(
type="m.something", room_id="!foo:bar", sender="@someone:somewhere" type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
@ -52,25 +48,19 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_user_id_prefix_match(self): def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self): def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.assertFalse((yield self.service.is_interested(self.event))) self.assertFalse((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_member_is_checked(self): def test_regex_room_member_is_checked(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member" self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org" self.event.state_key = "@irc_foobar:matrix.org"
@ -98,60 +88,47 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.store.get_aliases_for_room.return_value = [ self.store.get_aliases_for_room.return_value = [
"#irc_foobar:matrix.org", "#athing:matrix.org" "#irc_foobar:matrix.org",
"#athing:matrix.org",
] ]
self.store.get_users_in_room.return_value = [] self.store.get_users_in_room.return_value = []
self.assertTrue((yield self.service.is_interested( self.assertTrue((yield self.service.is_interested(self.event, self.store)))
self.event, self.store
)))
def test_non_exclusive_alias(self): def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org", exclusive=False) _regex("#irc_.*:matrix.org", exclusive=False)
) )
self.assertFalse(self.service.is_exclusive_alias( self.assertFalse(self.service.is_exclusive_alias("#irc_foobar:matrix.org"))
"#irc_foobar:matrix.org"
))
def test_non_exclusive_room(self): def test_non_exclusive_room(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!irc_.*:matrix.org", exclusive=False) _regex("!irc_.*:matrix.org", exclusive=False)
) )
self.assertFalse(self.service.is_exclusive_room( self.assertFalse(self.service.is_exclusive_room("!irc_foobar:matrix.org"))
"!irc_foobar:matrix.org"
))
def test_non_exclusive_user(self): def test_non_exclusive_user(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*:matrix.org", exclusive=False) _regex("@irc_.*:matrix.org", exclusive=False)
) )
self.assertFalse(self.service.is_exclusive_user( self.assertFalse(self.service.is_exclusive_user("@irc_foobar:matrix.org"))
"@irc_foobar:matrix.org"
))
def test_exclusive_alias(self): def test_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org", exclusive=True) _regex("#irc_.*:matrix.org", exclusive=True)
) )
self.assertTrue(self.service.is_exclusive_alias( self.assertTrue(self.service.is_exclusive_alias("#irc_foobar:matrix.org"))
"#irc_foobar:matrix.org"
))
def test_exclusive_user(self): def test_exclusive_user(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*:matrix.org", exclusive=True) _regex("@irc_.*:matrix.org", exclusive=True)
) )
self.assertTrue(self.service.is_exclusive_user( self.assertTrue(self.service.is_exclusive_user("@irc_foobar:matrix.org"))
"@irc_foobar:matrix.org"
))
def test_exclusive_room(self): def test_exclusive_room(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!irc_.*:matrix.org", exclusive=True) _regex("!irc_.*:matrix.org", exclusive=True)
) )
self.assertTrue(self.service.is_exclusive_room( self.assertTrue(self.service.is_exclusive_room("!irc_foobar:matrix.org"))
"!irc_foobar:matrix.org"
))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_alias_no_match(self): def test_regex_alias_no_match(self):
@ -159,47 +136,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.store.get_aliases_for_room.return_value = [ self.store.get_aliases_for_room.return_value = [
"#xmpp_foobar:matrix.org", "#athing:matrix.org" "#xmpp_foobar:matrix.org",
"#athing:matrix.org",
] ]
self.store.get_users_in_room.return_value = [] self.store.get_users_in_room.return_value = []
self.assertFalse((yield self.service.is_interested( self.assertFalse((yield self.service.is_interested(self.event, self.store)))
self.event, self.store
)))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_multiple_matches(self): def test_regex_multiple_matches(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"] self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
self.store.get_users_in_room.return_value = [] self.store.get_users_in_room.return_value = []
self.assertTrue((yield self.service.is_interested( self.assertTrue((yield self.service.is_interested(self.event, self.store)))
self.event, self.store
)))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_interested_in_self(self): def test_interested_in_self(self):
# make sure invites get through # make sure invites get through
self.service.sender = "@appservice:name" self.service.sender = "@appservice:name"
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.event.type = "m.room.member" self.event.type = "m.room.member"
self.event.content = { self.event.content = {"membership": "invite"}
"membership": "invite"
}
self.event.state_key = self.service.sender self.event.state_key = self.service.sender
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_member_list_match(self): def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.store.get_users_in_room.return_value = [ self.store.get_users_in_room.return_value = [
"@alice:here", "@alice:here",
"@irc_fo:here", # AS user "@irc_fo:here", # AS user
@ -208,6 +174,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.store.get_aliases_for_room.return_value = [] self.store.get_aliases_for_room.return_value = []
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue((yield self.service.is_interested( self.assertTrue(
event=self.event, store=self.store (yield self.service.is_interested(event=self.event, store=self.store))
))) )

View File

@ -30,7 +30,6 @@ from ..utils import MockClock
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.clock = MockClock() self.clock = MockClock()
self.store = Mock() self.store = Mock()
@ -38,8 +37,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.recoverer = Mock() self.recoverer = Mock()
self.recoverer_fn = Mock(return_value=self.recoverer) self.recoverer_fn = Mock(return_value=self.recoverer)
self.txnctrl = _TransactionController( self.txnctrl = _TransactionController(
clock=self.clock, store=self.store, as_api=self.as_api, clock=self.clock,
recoverer_fn=self.recoverer_fn store=self.store,
as_api=self.as_api,
recoverer_fn=self.recoverer_fn,
) )
def test_single_service_up_txn_sent(self): def test_single_service_up_txn_sent(self):
@ -54,9 +55,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
return_value=defer.succeed(ApplicationServiceState.UP) return_value=defer.succeed(ApplicationServiceState.UP)
) )
txn.send = Mock(return_value=defer.succeed(True)) txn.send = Mock(return_value=defer.succeed(True))
self.store.create_appservice_txn = Mock( self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
return_value=defer.succeed(txn)
)
# actual call # actual call
self.txnctrl.send(service, events) self.txnctrl.send(service, events)
@ -77,9 +76,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.get_appservice_state = Mock( self.store.get_appservice_state = Mock(
return_value=defer.succeed(ApplicationServiceState.DOWN) return_value=defer.succeed(ApplicationServiceState.DOWN)
) )
self.store.create_appservice_txn = Mock( self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
return_value=defer.succeed(txn)
)
# actual call # actual call
self.txnctrl.send(service, events) self.txnctrl.send(service, events)
@ -104,9 +101,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
) )
self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
txn.send = Mock(return_value=defer.succeed(False)) # fails to send txn.send = Mock(return_value=defer.succeed(False)) # fails to send
self.store.create_appservice_txn = Mock( self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
return_value=defer.succeed(txn)
)
# actual call # actual call
self.txnctrl.send(service, events) self.txnctrl.send(service, events)
@ -124,7 +119,6 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.clock = MockClock() self.clock = MockClock()
self.as_api = Mock() self.as_api = Mock()
@ -146,6 +140,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
def take_txn(*args, **kwargs): def take_txn(*args, **kwargs):
return defer.succeed(txns.pop(0)) return defer.succeed(txns.pop(0))
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
self.recoverer.recover() self.recoverer.recover()
@ -171,6 +166,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
return defer.succeed(txns.pop(0)) return defer.succeed(txns.pop(0))
else: else:
return defer.succeed(txn) return defer.succeed(txn)
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
self.recoverer.recover() self.recoverer.recover()
@ -197,7 +193,6 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.txn_ctrl = Mock() self.txn_ctrl = Mock()
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock()) self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
@ -211,9 +206,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def test_send_single_event_with_queue(self): def test_send_single_event_with_queue(self):
d = defer.Deferred() d = defer.Deferred()
self.txn_ctrl.send = Mock( self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d))
side_effect=lambda x, y: make_deferred_yieldable(d),
)
service = Mock(id=4) service = Mock(id=4)
event = Mock(event_id="first") event = Mock(event_id="first")
event2 = Mock(event_id="second") event2 = Mock(event_id="second")
@ -247,6 +240,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def do_send(x, y): def do_send(x, y):
return make_deferred_yieldable(send_return_list.pop(0)) return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send) self.txn_ctrl.send = Mock(side_effect=do_send)
# send events for different ASes and make sure they are sent # send events for different ASes and make sure they are sent

View File

@ -24,7 +24,6 @@ from tests import unittest
class ConfigGenerationTestCase(unittest.TestCase): class ConfigGenerationTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.dir = tempfile.mkdtemp() self.dir = tempfile.mkdtemp()
self.file = os.path.join(self.dir, "homeserver.yaml") self.file = os.path.join(self.dir, "homeserver.yaml")
@ -33,23 +32,30 @@ class ConfigGenerationTestCase(unittest.TestCase):
shutil.rmtree(self.dir) shutil.rmtree(self.dir)
def test_generate_config_generates_files(self): def test_generate_config_generates_files(self):
HomeServerConfig.load_or_generate_config("", [ HomeServerConfig.load_or_generate_config(
"--generate-config", "",
"-c", self.file, [
"--report-stats=yes", "--generate-config",
"-H", "lemurs.win" "-c",
]) self.file,
"--report-stats=yes",
"-H",
"lemurs.win",
],
)
self.assertSetEqual( self.assertSetEqual(
set([ set(
"homeserver.yaml", [
"lemurs.win.log.config", "homeserver.yaml",
"lemurs.win.signing.key", "lemurs.win.log.config",
"lemurs.win.tls.crt", "lemurs.win.signing.key",
"lemurs.win.tls.dh", "lemurs.win.tls.crt",
"lemurs.win.tls.key", "lemurs.win.tls.dh",
]), "lemurs.win.tls.key",
set(os.listdir(self.dir)) ]
),
set(os.listdir(self.dir)),
) )
self.assert_log_filename_is( self.assert_log_filename_is(

View File

@ -24,7 +24,6 @@ from tests import unittest
class ConfigLoadingTestCase(unittest.TestCase): class ConfigLoadingTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.dir = tempfile.mkdtemp() self.dir = tempfile.mkdtemp()
print(self.dir) print(self.dir)
@ -43,15 +42,14 @@ class ConfigLoadingTestCase(unittest.TestCase):
def test_generates_and_loads_macaroon_secret_key(self): def test_generates_and_loads_macaroon_secret_key(self):
self.generate_config() self.generate_config()
with open(self.file, with open(self.file, "r") as f:
"r") as f:
raw = yaml.load(f) raw = yaml.load(f)
self.assertIn("macaroon_secret_key", raw) self.assertIn("macaroon_secret_key", raw)
config = HomeServerConfig.load_config("", ["-c", self.file]) config = HomeServerConfig.load_config("", ["-c", self.file])
self.assertTrue( self.assertTrue(
hasattr(config, "macaroon_secret_key"), hasattr(config, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key" "Want config to have attr macaroon_secret_key",
) )
if len(config.macaroon_secret_key) < 5: if len(config.macaroon_secret_key) < 5:
self.fail( self.fail(
@ -62,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
self.assertTrue( self.assertTrue(
hasattr(config, "macaroon_secret_key"), hasattr(config, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key" "Want config to have attr macaroon_secret_key",
) )
if len(config.macaroon_secret_key) < 5: if len(config.macaroon_secret_key) < 5:
self.fail( self.fail(
@ -80,10 +78,9 @@ class ConfigLoadingTestCase(unittest.TestCase):
def test_disable_registration(self): def test_disable_registration(self):
self.generate_config() self.generate_config()
self.add_lines_to_config([ self.add_lines_to_config(
"enable_registration: true", ["enable_registration: true", "disable_registration: true"]
"disable_registration: true", )
])
# Check that disable_registration clobbers enable_registration. # Check that disable_registration clobbers enable_registration.
config = HomeServerConfig.load_config("", ["-c", self.file]) config = HomeServerConfig.load_config("", ["-c", self.file])
self.assertFalse(config.enable_registration) self.assertFalse(config.enable_registration)
@ -92,18 +89,23 @@ class ConfigLoadingTestCase(unittest.TestCase):
self.assertFalse(config.enable_registration) self.assertFalse(config.enable_registration)
# Check that either config value is clobbered by the command line. # Check that either config value is clobbered by the command line.
config = HomeServerConfig.load_or_generate_config("", [ config = HomeServerConfig.load_or_generate_config(
"-c", self.file, "--enable-registration" "", ["-c", self.file, "--enable-registration"]
]) )
self.assertTrue(config.enable_registration) self.assertTrue(config.enable_registration)
def generate_config(self): def generate_config(self):
HomeServerConfig.load_or_generate_config("", [ HomeServerConfig.load_or_generate_config(
"--generate-config", "",
"-c", self.file, [
"--report-stats=yes", "--generate-config",
"-H", "lemurs.win" "-c",
]) self.file,
"--report-stats=yes",
"-H",
"lemurs.win",
],
)
def generate_config_and_remove_lines_containing(self, needle): def generate_config_and_remove_lines_containing(self, needle):
self.generate_config() self.generate_config()

View File

@ -24,9 +24,7 @@ from tests import unittest
# Perform these tests using given secret key so we get entirely deterministic # Perform these tests using given secret key so we get entirely deterministic
# signatures output that we can test against. # signatures output that we can test against.
SIGNING_KEY_SEED = decode_base64( SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
"YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1"
)
KEY_ALG = "ed25519" KEY_ALG = "ed25519"
KEY_VER = 1 KEY_VER = 1
@ -36,7 +34,6 @@ HOSTNAME = "domain"
class EventSigningTestCase(unittest.TestCase): class EventSigningTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED) self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED)
self.signing_key.alg = KEY_ALG self.signing_key.alg = KEY_ALG
@ -51,7 +48,7 @@ class EventSigningTestCase(unittest.TestCase):
'signatures': {}, 'signatures': {},
'type': "X", 'type': "X",
'unsigned': {'age_ts': 1000000}, 'unsigned': {'age_ts': 1000000},
}, }
) )
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key) add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
@ -61,8 +58,7 @@ class EventSigningTestCase(unittest.TestCase):
self.assertTrue(hasattr(event, 'hashes')) self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes) self.assertIn('sha256', event.hashes)
self.assertEquals( self.assertEquals(
event.hashes['sha256'], event.hashes['sha256'], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI"
"6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI",
) )
self.assertTrue(hasattr(event, 'signatures')) self.assertTrue(hasattr(event, 'signatures'))
@ -77,9 +73,7 @@ class EventSigningTestCase(unittest.TestCase):
def test_sign_message(self): def test_sign_message(self):
builder = EventBuilder( builder = EventBuilder(
{ {
'content': { 'content': {'body': "Here is the message content"},
'body': "Here is the message content",
},
'event_id': "$0:domain", 'event_id': "$0:domain",
'origin': "domain", 'origin': "domain",
'origin_server_ts': 1000000, 'origin_server_ts': 1000000,
@ -98,8 +92,7 @@ class EventSigningTestCase(unittest.TestCase):
self.assertTrue(hasattr(event, 'hashes')) self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes) self.assertIn('sha256', event.hashes)
self.assertEquals( self.assertEquals(
event.hashes['sha256'], event.hashes['sha256'], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g"
"onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g",
) )
self.assertTrue(hasattr(event, 'signatures')) self.assertTrue(hasattr(event, 'signatures'))
@ -108,5 +101,5 @@ class EventSigningTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
event.signatures[HOSTNAME][KEY_NAME], event.signatures[HOSTNAME][KEY_NAME],
"Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw" "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw"
"u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA" "u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA",
) )

View File

@ -36,9 +36,7 @@ class MockPerspectiveServer(object):
def get_verify_keys(self): def get_verify_keys(self):
vk = signedjson.key.get_verify_key(self.key) vk = signedjson.key.get_verify_key(self.key)
return { return {"%s:%s" % (vk.alg, vk.version): vk}
"%s:%s" % (vk.alg, vk.version): vk,
}
def get_signed_key(self, server_name, verify_key): def get_signed_key(self, server_name, verify_key):
key_id = "%s:%s" % (verify_key.alg, verify_key.version) key_id = "%s:%s" % (verify_key.alg, verify_key.version)
@ -47,10 +45,8 @@ class MockPerspectiveServer(object):
"old_verify_keys": {}, "old_verify_keys": {},
"valid_until_ts": time.time() * 1000 + 3600, "valid_until_ts": time.time() * 1000 + 3600,
"verify_keys": { "verify_keys": {
key_id: { key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
"key": signedjson.key.encode_verify_key_base64(verify_key) },
}
}
} }
signedjson.sign.sign_json(res, self.server_name, self.key) signedjson.sign.sign_json(res, self.server_name, self.key)
return res return res
@ -62,18 +58,14 @@ class KeyringTestCase(unittest.TestCase):
self.mock_perspective_server = MockPerspectiveServer() self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock() self.http_client = Mock()
self.hs = yield utils.setup_test_homeserver( self.hs = yield utils.setup_test_homeserver(
handlers=None, self.addCleanup, handlers=None, http_client=self.http_client
http_client=self.http_client,
) )
self.hs.config.perspectives = { keys = self.mock_perspective_server.get_verify_keys()
self.mock_perspective_server.server_name: self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
self.mock_perspective_server.get_verify_keys()
}
def check_context(self, _, expected): def check_context(self, _, expected):
self.assertEquals( self.assertEquals(
getattr(LoggingContext.current_context(), "request", None), getattr(LoggingContext.current_context(), "request", None), expected
expected
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -89,8 +81,7 @@ class KeyringTestCase(unittest.TestCase):
context_one.request = "one" context_one.request = "one"
wait_1_deferred = kr.wait_for_previous_lookups( wait_1_deferred = kr.wait_for_previous_lookups(
["server1"], ["server1"], {"server1": lookup_1_deferred}
{"server1": lookup_1_deferred},
) )
# there were no previous lookups, so the deferred should be ready # there were no previous lookups, so the deferred should be ready
@ -105,8 +96,7 @@ class KeyringTestCase(unittest.TestCase):
# set off another wait. It should block because the first lookup # set off another wait. It should block because the first lookup
# hasn't yet completed. # hasn't yet completed.
wait_2_deferred = kr.wait_for_previous_lookups( wait_2_deferred = kr.wait_for_previous_lookups(
["server1"], ["server1"], {"server1": lookup_2_deferred}
{"server1": lookup_2_deferred},
) )
self.assertFalse(wait_2_deferred.called) self.assertFalse(wait_2_deferred.called)
# ... so we should have reset the LoggingContext. # ... so we should have reset the LoggingContext.
@ -132,21 +122,19 @@ class KeyringTestCase(unittest.TestCase):
persp_resp = { persp_resp = {
"server_keys": [ "server_keys": [
self.mock_perspective_server.get_signed_key( self.mock_perspective_server.get_signed_key(
"server10", "server10", signedjson.key.get_verify_key(key1)
signedjson.key.get_verify_key(key1) )
),
] ]
} }
persp_deferred = defer.Deferred() persp_deferred = defer.Deferred()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_perspectives(**kwargs): def get_perspectives(**kwargs):
self.assertEquals( self.assertEquals(LoggingContext.current_context().request, "11")
LoggingContext.current_context().request, "11",
)
with logcontext.PreserveLoggingContext(): with logcontext.PreserveLoggingContext():
yield persp_deferred yield persp_deferred
defer.returnValue(persp_resp) defer.returnValue(persp_resp)
self.http_client.post_json.side_effect = get_perspectives self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11: with LoggingContext("11") as context_11:
@ -154,9 +142,7 @@ class KeyringTestCase(unittest.TestCase):
# start off a first set of lookups # start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server( res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1), [("server10", json1), ("server11", {})]
("server11", {})
]
) )
# the unsigned json should be rejected pretty quickly # the unsigned json should be rejected pretty quickly
@ -172,7 +158,7 @@ class KeyringTestCase(unittest.TestCase):
# wait a tick for it to send the request to the perspectives server # wait a tick for it to send the request to the perspectives server
# (it first tries the datastore) # (it first tries the datastore)
yield clock.sleep(1) # XXX find out why this takes so long! yield clock.sleep(1) # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once() self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11) self.assertIs(LoggingContext.current_context(), context_11)
@ -186,7 +172,7 @@ class KeyringTestCase(unittest.TestCase):
self.http_client.post_json.return_value = defer.Deferred() self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server( res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)], [("server10", json1)]
) )
yield clock.sleep(1) yield clock.sleep(1)
self.http_client.post_json.assert_not_called() self.http_client.post_json.assert_not_called()
@ -207,8 +193,7 @@ class KeyringTestCase(unittest.TestCase):
key1 = signedjson.key.generate_signing_key(1) key1 = signedjson.key.generate_signing_key(1)
yield self.hs.datastore.store_server_verify_key( yield self.hs.datastore.store_server_verify_key(
"server9", "", time.time() * 1000, "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
signedjson.key.get_verify_key(key1),
) )
json1 = {} json1 = {}
signedjson.sign.sign_json(json1, "server9", key1) signedjson.sign.sign_json(json1, "server9", key1)

View File

@ -31,25 +31,20 @@ def MockEvent(**kwargs):
class PruneEventTestCase(unittest.TestCase): class PruneEventTestCase(unittest.TestCase):
""" Asserts that a new event constructed with `evdict` will look like """ Asserts that a new event constructed with `evdict` will look like
`matchdict` when it is redacted. """ `matchdict` when it is redacted. """
def run_test(self, evdict, matchdict): def run_test(self, evdict, matchdict):
self.assertEquals( self.assertEquals(prune_event(FrozenEvent(evdict)).get_dict(), matchdict)
prune_event(FrozenEvent(evdict)).get_dict(),
matchdict
)
def test_minimal(self): def test_minimal(self):
self.run_test( self.run_test(
{ {'type': 'A', 'event_id': '$test:domain'},
'type': 'A',
'event_id': '$test:domain',
},
{ {
'type': 'A', 'type': 'A',
'event_id': '$test:domain', 'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
} },
) )
def test_basic_keys(self): def test_basic_keys(self):
@ -70,23 +65,19 @@ class PruneEventTestCase(unittest.TestCase):
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
} },
) )
def test_unsigned_age_ts(self): def test_unsigned_age_ts(self):
self.run_test( self.run_test(
{ {'type': 'B', 'event_id': '$test:domain', 'unsigned': {'age_ts': 20}},
'type': 'B',
'event_id': '$test:domain',
'unsigned': {'age_ts': 20},
},
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain', 'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {'age_ts': 20}, 'unsigned': {'age_ts': 20},
} },
) )
self.run_test( self.run_test(
@ -101,23 +92,19 @@ class PruneEventTestCase(unittest.TestCase):
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
} },
) )
def test_content(self): def test_content(self):
self.run_test( self.run_test(
{ {'type': 'C', 'event_id': '$test:domain', 'content': {'things': 'here'}},
'type': 'C',
'event_id': '$test:domain',
'content': {'things': 'here'},
},
{ {
'type': 'C', 'type': 'C',
'event_id': '$test:domain', 'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
} },
) )
self.run_test( self.run_test(
@ -132,27 +119,20 @@ class PruneEventTestCase(unittest.TestCase):
'content': {'creator': '@2:domain'}, 'content': {'creator': '@2:domain'},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
} },
) )
class SerializeEventTestCase(unittest.TestCase): class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields): def serialize(self, ev, fields):
return serialize_event(ev, 1479807801915, only_event_fields=fields) return serialize_event(ev, 1479807801915, only_event_fields=fields)
def test_event_fields_works_with_keys(self): def test_event_fields_works_with_keys(self):
self.assertEquals( self.assertEquals(
self.serialize( self.serialize(
MockEvent( MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"]
sender="@alice:localhost",
room_id="!foo:bar"
),
["room_id"]
), ),
{ {"room_id": "!foo:bar"},
"room_id": "!foo:bar",
}
) )
def test_event_fields_works_with_nested_keys(self): def test_event_fields_works_with_nested_keys(self):
@ -161,17 +141,11 @@ class SerializeEventTestCase(unittest.TestCase):
MockEvent( MockEvent(
sender="@alice:localhost", sender="@alice:localhost",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"body": "A message"},
"body": "A message",
},
), ),
["content.body"] ["content.body"],
), ),
{ {"content": {"body": "A message"}},
"content": {
"body": "A message",
}
}
) )
def test_event_fields_works_with_dot_keys(self): def test_event_fields_works_with_dot_keys(self):
@ -180,17 +154,11 @@ class SerializeEventTestCase(unittest.TestCase):
MockEvent( MockEvent(
sender="@alice:localhost", sender="@alice:localhost",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"key.with.dots": {}},
"key.with.dots": {},
},
), ),
["content.key\.with\.dots"] ["content.key\.with\.dots"],
), ),
{ {"content": {"key.with.dots": {}}},
"content": {
"key.with.dots": {},
}
}
) )
def test_event_fields_works_with_nested_dot_keys(self): def test_event_fields_works_with_nested_dot_keys(self):
@ -201,21 +169,12 @@ class SerializeEventTestCase(unittest.TestCase):
room_id="!foo:bar", room_id="!foo:bar",
content={ content={
"not_me": 1, "not_me": 1,
"nested.dot.key": { "nested.dot.key": {"leaf.key": 42, "not_me_either": 1},
"leaf.key": 42,
"not_me_either": 1,
},
}, },
), ),
["content.nested\.dot\.key.leaf\.key"] ["content.nested\.dot\.key.leaf\.key"],
), ),
{ {"content": {"nested.dot.key": {"leaf.key": 42}}},
"content": {
"nested.dot.key": {
"leaf.key": 42,
},
}
}
) )
def test_event_fields_nops_with_unknown_keys(self): def test_event_fields_nops_with_unknown_keys(self):
@ -224,17 +183,11 @@ class SerializeEventTestCase(unittest.TestCase):
MockEvent( MockEvent(
sender="@alice:localhost", sender="@alice:localhost",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"foo": "bar"},
"foo": "bar",
},
), ),
["content.foo", "content.notexists"] ["content.foo", "content.notexists"],
), ),
{ {"content": {"foo": "bar"}},
"content": {
"foo": "bar",
}
}
) )
def test_event_fields_nops_with_non_dict_keys(self): def test_event_fields_nops_with_non_dict_keys(self):
@ -243,13 +196,11 @@ class SerializeEventTestCase(unittest.TestCase):
MockEvent( MockEvent(
sender="@alice:localhost", sender="@alice:localhost",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"foo": ["I", "am", "an", "array"]},
"foo": ["I", "am", "an", "array"],
},
), ),
["content.foo.am"] ["content.foo.am"],
), ),
{} {},
) )
def test_event_fields_nops_with_array_keys(self): def test_event_fields_nops_with_array_keys(self):
@ -258,13 +209,11 @@ class SerializeEventTestCase(unittest.TestCase):
MockEvent( MockEvent(
sender="@alice:localhost", sender="@alice:localhost",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"foo": ["I", "am", "an", "array"]},
"foo": ["I", "am", "an", "array"],
},
), ),
["content.foo.1"] ["content.foo.1"],
), ),
{} {},
) )
def test_event_fields_all_fields_if_empty(self): def test_event_fields_all_fields_if_empty(self):
@ -274,31 +223,21 @@ class SerializeEventTestCase(unittest.TestCase):
type="foo", type="foo",
event_id="test", event_id="test",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"foo": "bar"},
"foo": "bar",
},
), ),
[] [],
), ),
{ {
"type": "foo", "type": "foo",
"event_id": "test", "event_id": "test",
"room_id": "!foo:bar", "room_id": "!foo:bar",
"content": { "content": {"foo": "bar"},
"foo": "bar", "unsigned": {},
}, },
"unsigned": {}
}
) )
def test_event_fields_fail_if_fields_not_str(self): def test_event_fields_fail_if_fields_not_str(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
self.serialize( self.serialize(
MockEvent( MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4]
room_id="!foo:bar",
content={
"foo": "bar",
},
),
["room_id", 4]
) )

View File

@ -23,10 +23,7 @@ from tests import unittest
@unittest.DEBUG @unittest.DEBUG
class ServerACLsTestCase(unittest.TestCase): class ServerACLsTestCase(unittest.TestCase):
def test_blacklisted_server(self): def test_blacklisted_server(self):
e = _create_acl_event({ e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
"allow": ["*"],
"deny": ["evil.com"],
})
logging.info("ACL event: %s", e.content) logging.info("ACL event: %s", e.content)
self.assertFalse(server_matches_acl_event("evil.com", e)) self.assertFalse(server_matches_acl_event("evil.com", e))
@ -36,10 +33,7 @@ class ServerACLsTestCase(unittest.TestCase):
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
def test_block_ip_literals(self): def test_block_ip_literals(self):
e = _create_acl_event({ e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
"allow_ip_literals": False,
"allow": ["*"],
})
logging.info("ACL event: %s", e.content) logging.info("ACL event: %s", e.content)
self.assertFalse(server_matches_acl_event("1.2.3.4", e)) self.assertFalse(server_matches_acl_event("1.2.3.4", e))
@ -49,10 +43,12 @@ class ServerACLsTestCase(unittest.TestCase):
def _create_acl_event(content): def _create_acl_event(content):
return FrozenEvent({ return FrozenEvent(
"room_id": "!a:b", {
"event_id": "$a:b", "room_id": "!a:b",
"type": "m.room.server_acls", "event_id": "$a:b",
"sender": "@a:b", "type": "m.room.server_acls",
"content": content "sender": "@a:b",
}) "content": content,
}
)

View File

@ -45,20 +45,18 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [ services = [
self._mkservice(is_interested=False), self._mkservice(is_interested=False),
interested_service, interested_service,
self._mkservice(is_interested=False) self._mkservice(is_interested=False),
] ]
self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_user_by_id = Mock(return_value=[]) self.mock_store.get_user_by_id = Mock(return_value=[])
event = Mock( event = Mock(
sender="@someone:anywhere", sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
type="m.room.message",
room_id="!foo:bar"
) )
self.mock_store.get_new_events_for_appservice.side_effect = [ self.mock_store.get_new_events_for_appservice.side_effect = [
(0, [event]), (0, [event]),
(0, []) (0, []),
] ]
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
yield self.handler.notify_interested_services(0) yield self.handler.notify_interested_services(0)
@ -74,21 +72,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_user_by_id = Mock(return_value=None) self.mock_store.get_user_by_id = Mock(return_value=None)
event = Mock( event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
sender=user_id,
type="m.room.message",
room_id="!foo:bar"
)
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock() self.mock_as_api.query_user = Mock()
self.mock_store.get_new_events_for_appservice.side_effect = [ self.mock_store.get_new_events_for_appservice.side_effect = [
(0, [event]), (0, [event]),
(0, []) (0, []),
] ]
yield self.handler.notify_interested_services(0) yield self.handler.notify_interested_services(0)
self.mock_as_api.query_user.assert_called_once_with( self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
services[0], user_id
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_query_user_exists_known_user(self): def test_query_user_exists_known_user(self):
@ -96,25 +88,19 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested=True)] services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user = Mock(return_value=True) services[0].is_interested_in_user = Mock(return_value=True)
self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_user_by_id = Mock(return_value={ self.mock_store.get_user_by_id = Mock(return_value={"name": user_id})
"name": user_id
})
event = Mock( event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
sender=user_id,
type="m.room.message",
room_id="!foo:bar"
)
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock() self.mock_as_api.query_user = Mock()
self.mock_store.get_new_events_for_appservice.side_effect = [ self.mock_store.get_new_events_for_appservice.side_effect = [
(0, [event]), (0, [event]),
(0, []) (0, []),
] ]
yield self.handler.notify_interested_services(0) yield self.handler.notify_interested_services(0)
self.assertFalse( self.assertFalse(
self.mock_as_api.query_user.called, self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been." "query_user called when it shouldn't have been.",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -129,7 +115,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [ services = [
self._mkservice_alias(is_interested_in_alias=False), self._mkservice_alias(is_interested_in_alias=False),
interested_service, interested_service,
self._mkservice_alias(is_interested_in_alias=False) self._mkservice_alias(is_interested_in_alias=False),
] ]
self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_app_services = Mock(return_value=services)
@ -140,8 +126,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
result = yield self.handler.query_room_alias_exists(room_alias) result = yield self.handler.query_room_alias_exists(room_alias)
self.mock_as_api.query_alias.assert_called_once_with( self.mock_as_api.query_alias.assert_called_once_with(
interested_service, interested_service, room_alias_str
room_alias_str
) )
self.assertEquals(result.room_id, room_id) self.assertEquals(result.room_id, room_id)
self.assertEquals(result.servers, servers) self.assertEquals(result.servers, servers)

View File

@ -35,7 +35,7 @@ class AuthHandlers(object):
class AuthTestCase(unittest.TestCase): class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.hs = yield setup_test_homeserver(handlers=None) self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
self.hs.handlers = AuthHandlers(self.hs) self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator() self.macaroon_generator = self.hs.get_macaroon_generator()
@ -81,9 +81,7 @@ class AuthTestCase(unittest.TestCase):
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000 self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
"a_user", 5000
)
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
token token
) )
@ -98,17 +96,13 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_short_term_login_token_cannot_replace_user_id(self): def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
"a_user", 5000
)
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize() macaroon.serialize()
) )
self.assertEqual( self.assertEqual("a_user", user_id)
"a_user", user_id
)
# add another "user_id" caveat, which might allow us to override the # add another "user_id" caveat, which might allow us to override the
# user_id. # user_id.
@ -165,7 +159,5 @@ class AuthTestCase(unittest.TestCase):
) )
def _get_macaroon(self): def _get_macaroon(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
"user_a", 5000
)
return pymacaroons.Macaroon.deserialize(token) return pymacaroons.Macaroon.deserialize(token)

View File

@ -28,13 +28,13 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.TestCase): class DeviceTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(DeviceTestCase, self).__init__(*args, **kwargs) super(DeviceTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore self.store = None # type: synapse.storage.DataStore
self.handler = None # type: synapse.handlers.device.DeviceHandler self.handler = None # type: synapse.handlers.device.DeviceHandler
self.clock = None # type: utils.MockClock self.clock = None # type: utils.MockClock
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield utils.setup_test_homeserver() hs = yield utils.setup_test_homeserver(self.addCleanup)
self.handler = hs.get_device_handler() self.handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -44,7 +44,7 @@ class DeviceTestCase(unittest.TestCase):
res = yield self.handler.check_device_registered( res = yield self.handler.check_device_registered(
user_id="@boris:foo", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="display name" initial_device_display_name="display name",
) )
self.assertEqual(res, "fco") self.assertEqual(res, "fco")
@ -56,14 +56,14 @@ class DeviceTestCase(unittest.TestCase):
res1 = yield self.handler.check_device_registered( res1 = yield self.handler.check_device_registered(
user_id="@boris:foo", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="display name" initial_device_display_name="display name",
) )
self.assertEqual(res1, "fco") self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered( res2 = yield self.handler.check_device_registered(
user_id="@boris:foo", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="new display name" initial_device_display_name="new display name",
) )
self.assertEqual(res2, "fco") self.assertEqual(res2, "fco")
@ -75,7 +75,7 @@ class DeviceTestCase(unittest.TestCase):
device_id = yield self.handler.check_device_registered( device_id = yield self.handler.check_device_registered(
user_id="@theresa:foo", user_id="@theresa:foo",
device_id=None, device_id=None,
initial_device_display_name="display" initial_device_display_name="display",
) )
dev = yield self.handler.store.get_device("@theresa:foo", device_id) dev = yield self.handler.store.get_device("@theresa:foo", device_id)
@ -87,43 +87,53 @@ class DeviceTestCase(unittest.TestCase):
res = yield self.handler.get_devices_by_user(user1) res = yield self.handler.get_devices_by_user(user1)
self.assertEqual(3, len(res)) self.assertEqual(3, len(res))
device_map = { device_map = {d["device_id"]: d for d in res}
d["device_id"]: d for d in res self.assertDictContainsSubset(
} {
self.assertDictContainsSubset({ "user_id": user1,
"user_id": user1, "device_id": "xyz",
"device_id": "xyz", "display_name": "display 0",
"display_name": "display 0", "last_seen_ip": None,
"last_seen_ip": None, "last_seen_ts": None,
"last_seen_ts": None, },
}, device_map["xyz"]) device_map["xyz"],
self.assertDictContainsSubset({ )
"user_id": user1, self.assertDictContainsSubset(
"device_id": "fco", {
"display_name": "display 1", "user_id": user1,
"last_seen_ip": "ip1", "device_id": "fco",
"last_seen_ts": 1000000, "display_name": "display 1",
}, device_map["fco"]) "last_seen_ip": "ip1",
self.assertDictContainsSubset({ "last_seen_ts": 1000000,
"user_id": user1, },
"device_id": "abc", device_map["fco"],
"display_name": "display 2", )
"last_seen_ip": "ip3", self.assertDictContainsSubset(
"last_seen_ts": 3000000, {
}, device_map["abc"]) "user_id": user1,
"device_id": "abc",
"display_name": "display 2",
"last_seen_ip": "ip3",
"last_seen_ts": 3000000,
},
device_map["abc"],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_device(self): def test_get_device(self):
yield self._record_users() yield self._record_users()
res = yield self.handler.get_device(user1, "abc") res = yield self.handler.get_device(user1, "abc")
self.assertDictContainsSubset({ self.assertDictContainsSubset(
"user_id": user1, {
"device_id": "abc", "user_id": user1,
"display_name": "display 2", "device_id": "abc",
"last_seen_ip": "ip3", "display_name": "display 2",
"last_seen_ts": 3000000, "last_seen_ip": "ip3",
}, res) "last_seen_ts": 3000000,
},
res,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_delete_device(self): def test_delete_device(self):
@ -153,8 +163,7 @@ class DeviceTestCase(unittest.TestCase):
def test_update_unknown_device(self): def test_update_unknown_device(self):
update = {"display_name": "new_display"} update = {"display_name": "new_display"}
with self.assertRaises(synapse.api.errors.NotFoundError): with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.update_device("user_id", "unknown_device_id", yield self.handler.update_device("user_id", "unknown_device_id", update)
update)
@defer.inlineCallbacks @defer.inlineCallbacks
def _record_users(self): def _record_users(self):
@ -168,16 +177,17 @@ class DeviceTestCase(unittest.TestCase):
yield self._record_user(user2, "def", "dispkay", "token4", "ip4") yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
@defer.inlineCallbacks @defer.inlineCallbacks
def _record_user(self, user_id, device_id, display_name, def _record_user(
access_token=None, ip=None): self, user_id, device_id, display_name, access_token=None, ip=None
):
device_id = yield self.handler.check_device_registered( device_id = yield self.handler.check_device_registered(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
initial_device_display_name=display_name initial_device_display_name=display_name,
) )
if ip is not None: if ip is not None:
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, user_id, access_token, ip, "user_agent", device_id
access_token, ip, "user_agent", device_id) )
self.clock.advance_time(1000) self.clock.advance_time(1000)

View File

@ -42,9 +42,11 @@ class DirectoryTestCase(unittest.TestCase):
def register_query_handler(query_type, handler): def register_query_handler(query_type, handler):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
self.addCleanup,
http_client=None, http_client=None,
resource_for_federation=Mock(), resource_for_federation=Mock(),
federation_client=self.mock_federation, federation_client=self.mock_federation,
@ -68,10 +70,7 @@ class DirectoryTestCase(unittest.TestCase):
result = yield self.handler.get_association(self.my_room) result = yield self.handler.get_association(self.my_room)
self.assertEquals({ self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
"room_id": "!8765qwer:test",
"servers": ["test"],
}, result)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_remote_association(self): def test_get_remote_association(self):
@ -81,16 +80,13 @@ class DirectoryTestCase(unittest.TestCase):
result = yield self.handler.get_association(self.remote_room) result = yield self.handler.get_association(self.remote_room)
self.assertEquals({ self.assertEquals(
"room_id": "!8765qwer:test", {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result
"servers": ["test", "remote"], )
}, result)
self.mock_federation.make_query.assert_called_with( self.mock_federation.make_query.assert_called_with(
destination="remote", destination="remote",
query_type="directory", query_type="directory",
args={ args={"room_alias": "#another:remote"},
"room_alias": "#another:remote",
},
retry_on_dns_fail=False, retry_on_dns_fail=False,
ignore_backoff=True, ignore_backoff=True,
) )
@ -105,7 +101,4 @@ class DirectoryTestCase(unittest.TestCase):
{"room_alias": "#your-room:test"} {"room_alias": "#your-room:test"}
) )
self.assertEquals({ self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
"room_id": "!8765asdf:test",
"servers": ["test"],
}, response)

View File

@ -28,14 +28,13 @@ from tests import unittest, utils
class E2eKeysHandlerTestCase(unittest.TestCase): class E2eKeysHandlerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs) super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.hs = yield utils.setup_test_homeserver( self.hs = yield utils.setup_test_homeserver(
handlers=None, self.addCleanup, handlers=None, federation_client=mock.Mock()
federation_client=mock.Mock(),
) )
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
@ -54,30 +53,21 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_id = "xyz" device_id = "xyz"
keys = { keys = {
"alg1:k1": "key1", "alg1:k1": "key1",
"alg2:k2": { "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"key": "key2", "alg2:k3": {"key": "key3"},
"signatures": {"k1": "sig1"}
},
"alg2:k3": {
"key": "key3",
},
} }
res = yield self.handler.upload_keys_for_user( res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}, local_user, device_id, {"one_time_keys": keys}
) )
self.assertDictEqual(res, { self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})
# we should be able to change the signature without a problem # we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2" keys["alg2:k2"]["signatures"]["k1"] = "sig2"
res = yield self.handler.upload_keys_for_user( res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}, local_user, device_id, {"one_time_keys": keys}
) )
self.assertDictEqual(res, { self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})
@defer.inlineCallbacks @defer.inlineCallbacks
def test_change_one_time_keys(self): def test_change_one_time_keys(self):
@ -87,25 +77,18 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_id = "xyz" device_id = "xyz"
keys = { keys = {
"alg1:k1": "key1", "alg1:k1": "key1",
"alg2:k2": { "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"key": "key2", "alg2:k3": {"key": "key3"},
"signatures": {"k1": "sig1"}
},
"alg2:k3": {
"key": "key3",
},
} }
res = yield self.handler.upload_keys_for_user( res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}, local_user, device_id, {"one_time_keys": keys}
) )
self.assertDictEqual(res, { self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})
try: try:
yield self.handler.upload_keys_for_user( yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}, local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
) )
self.fail("No error when changing string key") self.fail("No error when changing string key")
except errors.SynapseError: except errors.SynapseError:
@ -113,7 +96,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
try: try:
yield self.handler.upload_keys_for_user( yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}, local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
) )
self.fail("No error when replacing dict key with string") self.fail("No error when replacing dict key with string")
except errors.SynapseError: except errors.SynapseError:
@ -121,9 +104,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
try: try:
yield self.handler.upload_keys_for_user( yield self.handler.upload_keys_for_user(
local_user, device_id, { local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
"one_time_keys": {"alg1:k1": {"key": "key"}}
},
) )
self.fail("No error when replacing string key with dict") self.fail("No error when replacing string key with dict")
except errors.SynapseError: except errors.SynapseError:
@ -131,13 +112,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
try: try:
yield self.handler.upload_keys_for_user( yield self.handler.upload_keys_for_user(
local_user, device_id, { local_user,
device_id,
{
"one_time_keys": { "one_time_keys": {
"alg2:k2": { "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
"key": "key3", }
"signatures": {"k1": "sig1"},
}
},
}, },
) )
self.fail("No error when replacing dict key") self.fail("No error when replacing dict key")
@ -148,31 +128,20 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
def test_claim_one_time_key(self): def test_claim_one_time_key(self):
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
keys = { keys = {"alg1:k1": "key1"}
"alg1:k1": "key1",
}
res = yield self.handler.upload_keys_for_user( res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}, local_user, device_id, {"one_time_keys": keys}
) )
self.assertDictEqual(res, { self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
"one_time_key_counts": {"alg1": 1}
})
res2 = yield self.handler.claim_one_time_keys({ res2 = yield self.handler.claim_one_time_keys(
"one_time_keys": { {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
local_user: { )
device_id: "alg1" self.assertEqual(
} res2,
} {
}, timeout=None) "failures": {},
self.assertEqual(res2, { "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
"failures": {}, },
"one_time_keys": { )
local_user: {
device_id: {
"alg1:k1": "key1"
}
}
}
})

View File

@ -39,8 +39,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
prev_state = UserPresenceState.default(user_id) prev_state = UserPresenceState.default(user_id)
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now
last_active_ts=now,
) )
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
@ -54,23 +53,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(state.last_federation_update_ts, now)
self.assertEquals(wheel_timer.insert.call_count, 3) self.assertEquals(wheel_timer.insert.call_count, 3)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
call( [
now=now, call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
obj=user_id, call(
then=new_state.last_active_ts + IDLE_TIMER now=now,
), obj=user_id,
call( then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
now=now, ),
obj=user_id, call(
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT now=now,
), obj=user_id,
call( then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
now=now, ),
obj=user_id, ],
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY any_order=True,
), )
], any_order=True)
def test_online_to_online(self): def test_online_to_online(self):
wheel_timer = Mock() wheel_timer = Mock()
@ -79,14 +77,11 @@ class PresenceUpdateTestCase(unittest.TestCase):
prev_state = UserPresenceState.default(user_id) prev_state = UserPresenceState.default(user_id)
prev_state = prev_state.copy_and_replace( prev_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
last_active_ts=now,
currently_active=True,
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now
last_active_ts=now,
) )
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
@ -101,23 +96,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(state.last_federation_update_ts, now)
self.assertEquals(wheel_timer.insert.call_count, 3) self.assertEquals(wheel_timer.insert.call_count, 3)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
call( [
now=now, call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
obj=user_id, call(
then=new_state.last_active_ts + IDLE_TIMER now=now,
), obj=user_id,
call( then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
now=now, ),
obj=user_id, call(
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT now=now,
), obj=user_id,
call( then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
now=now, ),
obj=user_id, ],
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY any_order=True,
), )
], any_order=True)
def test_online_to_online_last_active_noop(self): def test_online_to_online_last_active_noop(self):
wheel_timer = Mock() wheel_timer = Mock()
@ -132,8 +126,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now
last_active_ts=now,
) )
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
@ -148,23 +141,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(state.last_federation_update_ts, now)
self.assertEquals(wheel_timer.insert.call_count, 3) self.assertEquals(wheel_timer.insert.call_count, 3)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
call( [
now=now, call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
obj=user_id, call(
then=new_state.last_active_ts + IDLE_TIMER now=now,
), obj=user_id,
call( then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
now=now, ),
obj=user_id, call(
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT now=now,
), obj=user_id,
call( then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
now=now, ),
obj=user_id, ],
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY any_order=True,
), )
], any_order=True)
def test_online_to_online_last_active(self): def test_online_to_online_last_active(self):
wheel_timer = Mock() wheel_timer = Mock()
@ -178,9 +170,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
currently_active=True, currently_active=True,
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE)
state=PresenceState.ONLINE,
)
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
@ -193,18 +183,17 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(state.last_federation_update_ts, now)
self.assertEquals(wheel_timer.insert.call_count, 2) self.assertEquals(wheel_timer.insert.call_count, 2)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
call( [
now=now, call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
obj=user_id, call(
then=new_state.last_active_ts + IDLE_TIMER now=now,
), obj=user_id,
call( then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
now=now, ),
obj=user_id, ],
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT any_order=True,
) )
], any_order=True)
def test_remote_ping_timer(self): def test_remote_ping_timer(self):
wheel_timer = Mock() wheel_timer = Mock()
@ -213,13 +202,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
prev_state = UserPresenceState.default(user_id) prev_state = UserPresenceState.default(user_id)
prev_state = prev_state.copy_and_replace( prev_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now
last_active_ts=now,
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE)
state=PresenceState.ONLINE,
)
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
prev_state, new_state, is_mine=False, wheel_timer=wheel_timer, now=now prev_state, new_state, is_mine=False, wheel_timer=wheel_timer, now=now
@ -232,13 +218,16 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(new_state.status_msg, state.status_msg) self.assertEquals(new_state.status_msg, state.status_msg)
self.assertEquals(wheel_timer.insert.call_count, 1) self.assertEquals(wheel_timer.insert.call_count, 1)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
call( [
now=now, call(
obj=user_id, now=now,
then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT obj=user_id,
), then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT,
], any_order=True) )
],
any_order=True,
)
def test_online_to_offline(self): def test_online_to_offline(self):
wheel_timer = Mock() wheel_timer = Mock()
@ -247,14 +236,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
prev_state = UserPresenceState.default(user_id) prev_state = UserPresenceState.default(user_id)
prev_state = prev_state.copy_and_replace( prev_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
last_active_ts=now,
currently_active=True,
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(state=PresenceState.OFFLINE)
state=PresenceState.OFFLINE,
)
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
@ -273,14 +258,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
prev_state = UserPresenceState.default(user_id) prev_state = UserPresenceState.default(user_id)
prev_state = prev_state.copy_and_replace( prev_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
last_active_ts=now,
currently_active=True,
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(state=PresenceState.UNAVAILABLE)
state=PresenceState.UNAVAILABLE,
)
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
@ -293,13 +274,16 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(new_state.status_msg, state.status_msg) self.assertEquals(new_state.status_msg, state.status_msg)
self.assertEquals(wheel_timer.insert.call_count, 1) self.assertEquals(wheel_timer.insert.call_count, 1)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
call( [
now=now, call(
obj=user_id, now=now,
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT obj=user_id,
) then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
], any_order=True) )
],
any_order=True,
)
class PresenceTimeoutTestCase(unittest.TestCase): class PresenceTimeoutTestCase(unittest.TestCase):
@ -314,9 +298,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now, last_user_sync_ts=now,
) )
new_state = handle_timeout( new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
@ -332,9 +314,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
) )
new_state = handle_timeout( new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.OFFLINE) self.assertEquals(new_state.state, PresenceState.OFFLINE)
@ -369,9 +349,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1, last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1,
) )
new_state = handle_timeout( new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
self.assertEquals(new_state, new_state) self.assertEquals(new_state, new_state)
@ -388,9 +366,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_federation_update_ts=now, last_federation_update_ts=now,
) )
new_state = handle_timeout( new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNone(new_state) self.assertIsNone(new_state)
@ -425,9 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_federation_update_ts=now, last_federation_update_ts=now,
) )
new_state = handle_timeout( new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
self.assertEquals(state, new_state) self.assertEquals(state, new_state)

View File

@ -48,15 +48,14 @@ class ProfileTestCase(unittest.TestCase):
self.mock_registry.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
self.addCleanup,
http_client=None, http_client=None,
handlers=None, handlers=None,
resource_for_federation=Mock(), resource_for_federation=Mock(),
federation_client=self.mock_federation, federation_client=self.mock_federation,
federation_server=Mock(), federation_server=Mock(),
federation_registry=self.mock_registry, federation_registry=self.mock_registry,
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=["send_message"]),
"send_message",
])
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
@ -74,9 +73,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_name(self): def test_get_my_name(self):
yield self.store.set_profile_displayname( yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
self.frank.localpart, "Frank"
)
displayname = yield self.handler.get_displayname(self.frank) displayname = yield self.handler.get_displayname(self.frank)
@ -85,22 +82,18 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name(self): def test_set_my_name(self):
yield self.handler.set_displayname( yield self.handler.set_displayname(
self.frank, self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
synapse.types.create_requester(self.frank),
"Frank Jr."
) )
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), (yield self.store.get_profile_displayname(self.frank.localpart)),
"Frank Jr." "Frank Jr.",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name_noauth(self): def test_set_my_name_noauth(self):
d = self.handler.set_displayname( d = self.handler.set_displayname(
self.frank, self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
synapse.types.create_requester(self.bob),
"Frank Jr."
) )
yield self.assertFailure(d, AuthError) yield self.assertFailure(d, AuthError)
@ -145,11 +138,12 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_avatar(self): def test_set_my_avatar(self):
yield self.handler.set_avatar_url( yield self.handler.set_avatar_url(
self.frank, synapse.types.create_requester(self.frank), self.frank,
"http://my.server/pic.gif" synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
) )
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)), (yield self.store.get_profile_avatar_url(self.frank.localpart)),
"http://my.server/pic.gif" "http://my.server/pic.gif",
) )

View File

@ -17,7 +17,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import RegistrationError from synapse.api.errors import AuthError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -40,13 +40,15 @@ class RegistrationTestCase(unittest.TestCase):
self.mock_distributor.declare("registered_user") self.mock_distributor.declare("registered_user")
self.mock_captcha_client = Mock() self.mock_captcha_client = Mock()
self.hs = yield setup_test_homeserver( self.hs = yield setup_test_homeserver(
self.addCleanup,
handlers=None, handlers=None,
http_client=None, http_client=None,
expire_access_token=True, expire_access_token=True,
profile_handler=Mock(), profile_handler=Mock(),
) )
self.macaroon_generator = Mock( self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')) generate_access_token=Mock(return_value='secret')
)
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs) self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler self.handler = self.hs.get_handlers().registration_handler
@ -62,7 +64,8 @@ class RegistrationTestCase(unittest.TestCase):
user_id = "@someone:test" user_id = "@someone:test"
requester = create_requester("@as:test") requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
requester, local_part, display_name) requester, local_part, display_name
)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@ -73,13 +76,15 @@ class RegistrationTestCase(unittest.TestCase):
yield store.register( yield store.register(
user_id=frank.to_string(), user_id=frank.to_string(),
token="jkv;g498752-43gj['eamb!-5", token="jkv;g498752-43gj['eamb!-5",
password_hash=None) password_hash=None,
)
local_part = "frank" local_part = "frank"
display_name = "Frank" display_name = "Frank"
user_id = "@frank:test" user_id = "@frank:test"
requester = create_requester("@as:test") requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
requester, local_part, display_name) requester, local_part, display_name
)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@ -104,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(RegistrationError): with self.assertRaises(AuthError):
yield self.handler.get_or_create_user("requester", 'b', "display_name") yield self.handler.get_or_create_user("requester", 'b', "display_name")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -113,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(RegistrationError): with self.assertRaises(AuthError):
yield self.handler.register(localpart="local_part") yield self.handler.register(localpart="local_part")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -122,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(RegistrationError): with self.assertRaises(AuthError):
yield self.handler.register_saml2(localpart="local_part") yield self.handler.register_saml2(localpart="local_part")

Some files were not shown because too many files have changed in this diff Show More