mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-02-17 21:14:07 -05:00
Merge branch 'develop' of github.com:matrix-org/synapse into neilj/server_notices_on_blocking
This commit is contained in:
commit
fc5d937550
48
.circleci/config.yml
Normal file
48
.circleci/config.yml
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
version: 2
|
||||||
|
jobs:
|
||||||
|
sytestpy2:
|
||||||
|
machine: true
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run: docker pull matrixdotorg/sytest-synapsepy2
|
||||||
|
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs matrixdotorg/sytest-synapsepy2
|
||||||
|
- store_artifacts:
|
||||||
|
path: ~/project/logs
|
||||||
|
destination: logs
|
||||||
|
sytestpy2postgres:
|
||||||
|
machine: true
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run: docker pull matrixdotorg/sytest-synapsepy2
|
||||||
|
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy2
|
||||||
|
- store_artifacts:
|
||||||
|
path: ~/project/logs
|
||||||
|
destination: logs
|
||||||
|
sytestpy3:
|
||||||
|
machine: true
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run: docker pull matrixdotorg/sytest-synapsepy3
|
||||||
|
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs hawkowl/sytestpy3
|
||||||
|
- store_artifacts:
|
||||||
|
path: ~/project/logs
|
||||||
|
destination: logs
|
||||||
|
sytestpy3postgres:
|
||||||
|
machine: true
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run: docker pull matrixdotorg/sytest-synapsepy3
|
||||||
|
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy3
|
||||||
|
- store_artifacts:
|
||||||
|
path: ~/project/logs
|
||||||
|
destination: logs
|
||||||
|
|
||||||
|
workflows:
|
||||||
|
version: 2
|
||||||
|
build:
|
||||||
|
jobs:
|
||||||
|
- sytestpy2
|
||||||
|
- sytestpy2postgres
|
||||||
|
# Currently broken while the Python 3 port is incomplete
|
||||||
|
# - sytestpy3
|
||||||
|
# - sytestpy3postgres
|
@ -3,3 +3,6 @@ Dockerfile
|
|||||||
.gitignore
|
.gitignore
|
||||||
demo/etc
|
demo/etc
|
||||||
tox.ini
|
tox.ini
|
||||||
|
synctl
|
||||||
|
.git/*
|
||||||
|
.tox/*
|
||||||
|
10
.travis.yml
10
.travis.yml
@ -8,6 +8,9 @@ before_script:
|
|||||||
- git remote set-branches --add origin develop
|
- git remote set-branches --add origin develop
|
||||||
- git fetch origin develop
|
- git fetch origin develop
|
||||||
|
|
||||||
|
services:
|
||||||
|
- postgresql
|
||||||
|
|
||||||
matrix:
|
matrix:
|
||||||
fast_finish: true
|
fast_finish: true
|
||||||
include:
|
include:
|
||||||
@ -20,6 +23,9 @@ matrix:
|
|||||||
- python: 2.7
|
- python: 2.7
|
||||||
env: TOX_ENV=py27
|
env: TOX_ENV=py27
|
||||||
|
|
||||||
|
- python: 2.7
|
||||||
|
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
|
||||||
|
|
||||||
- python: 3.6
|
- python: 3.6
|
||||||
env: TOX_ENV=py36
|
env: TOX_ENV=py36
|
||||||
|
|
||||||
@ -29,6 +35,10 @@ matrix:
|
|||||||
- python: 3.6
|
- python: 3.6
|
||||||
env: TOX_ENV=check-newsfragment
|
env: TOX_ENV=check-newsfragment
|
||||||
|
|
||||||
|
allow_failures:
|
||||||
|
- python: 2.7
|
||||||
|
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
|
||||||
|
|
||||||
install:
|
install:
|
||||||
- pip install tox
|
- pip install tox
|
||||||
|
|
||||||
|
@ -36,3 +36,4 @@ recursive-include changelog.d *
|
|||||||
prune .github
|
prune .github
|
||||||
prune demo/etc
|
prune demo/etc
|
||||||
prune docker
|
prune docker
|
||||||
|
prune .circleci
|
||||||
|
1
changelog.d/1491.feature
Normal file
1
changelog.d/1491.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add support for the SNI extension to federation TLS connections
|
1
changelog.d/3423.misc
Normal file
1
changelog.d/3423.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
The test suite now can run under PostgreSQL.
|
1
changelog.d/3653.feature
Normal file
1
changelog.d/3653.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Support more federation endpoints on workers
|
1
changelog.d/3660.misc
Normal file
1
changelog.d/3660.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Sytests can now be run inside a Docker container.
|
1
changelog.d/3661.bugfix
Normal file
1
changelog.d/3661.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix bug on deleting 3pid when using identity servers that don't support unbind API
|
1
changelog.d/3669.misc
Normal file
1
changelog.d/3669.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Update docker base image from alpine 3.7 to 3.8.
|
1
changelog.d/3676.bugfix
Normal file
1
changelog.d/3676.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Make the tests pass on Twisted < 18.7.0
|
1
changelog.d/3677.bugfix
Normal file
1
changelog.d/3677.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Don’t ship recaptcha_ajax.js, use it directly from Google
|
1
changelog.d/3678.misc
Normal file
1
changelog.d/3678.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Rename synapse.util.async to synapse.util.async_helpers to mitigate async becoming a keyword on Python 3.7.
|
1
changelog.d/3679.misc
Normal file
1
changelog.d/3679.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Synapse's tests are now formatted with the black autoformatter.
|
1
changelog.d/3681.bugfix
Normal file
1
changelog.d/3681.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fixes test_reap_monthly_active_users so it passes under postgres
|
1
changelog.d/3684.misc
Normal file
1
changelog.d/3684.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Implemented a new testing base class to reduce test boilerplate.
|
1
changelog.d/3687.feature
Normal file
1
changelog.d/3687.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
set admin uri via config, to be used in error messages where the user should contact the administrator
|
1
changelog.d/3690.misc
Normal file
1
changelog.d/3690.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Rename MAU prometheus metrics
|
1
changelog.d/3692.bugfix
Normal file
1
changelog.d/3692.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix missing yield in synapse.storage.monthly_active_users.initialise_reserved_users
|
@ -1,4 +1,4 @@
|
|||||||
FROM docker.io/python:2-alpine3.7
|
FROM docker.io/python:2-alpine3.8
|
||||||
|
|
||||||
RUN apk add --no-cache --virtual .nacl_deps \
|
RUN apk add --no-cache --virtual .nacl_deps \
|
||||||
build-base \
|
build-base \
|
||||||
|
@ -173,10 +173,23 @@ endpoints matching the following regular expressions::
|
|||||||
^/_matrix/federation/v1/backfill/
|
^/_matrix/federation/v1/backfill/
|
||||||
^/_matrix/federation/v1/get_missing_events/
|
^/_matrix/federation/v1/get_missing_events/
|
||||||
^/_matrix/federation/v1/publicRooms
|
^/_matrix/federation/v1/publicRooms
|
||||||
|
^/_matrix/federation/v1/query/
|
||||||
|
^/_matrix/federation/v1/make_join/
|
||||||
|
^/_matrix/federation/v1/make_leave/
|
||||||
|
^/_matrix/federation/v1/send_join/
|
||||||
|
^/_matrix/federation/v1/send_leave/
|
||||||
|
^/_matrix/federation/v1/invite/
|
||||||
|
^/_matrix/federation/v1/query_auth/
|
||||||
|
^/_matrix/federation/v1/event_auth/
|
||||||
|
^/_matrix/federation/v1/exchange_third_party_invite/
|
||||||
|
^/_matrix/federation/v1/send/
|
||||||
|
|
||||||
The above endpoints should all be routed to the federation_reader worker by the
|
The above endpoints should all be routed to the federation_reader worker by the
|
||||||
reverse-proxy configuration.
|
reverse-proxy configuration.
|
||||||
|
|
||||||
|
The `^/_matrix/federation/v1/send/` endpoint must only be handled by a single
|
||||||
|
instance.
|
||||||
|
|
||||||
``synapse.app.federation_sender``
|
``synapse.app.federation_sender``
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
@ -780,11 +780,14 @@ class Auth(object):
|
|||||||
such as monthly active user limiting or global disable flag
|
such as monthly active user limiting or global disable flag
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): If present, checks for presence against existing MAU cohort
|
user_id(str|None): If present, checks for presence against existing
|
||||||
|
MAU cohort
|
||||||
"""
|
"""
|
||||||
if self.hs.config.hs_disabled:
|
if self.hs.config.hs_disabled:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED
|
403, self.hs.config.hs_disabled_message,
|
||||||
|
errcode=Codes.RESOURCE_LIMIT_EXCEED,
|
||||||
|
admin_uri=self.hs.config.admin_uri,
|
||||||
)
|
)
|
||||||
if self.hs.config.limit_usage_by_mau is True:
|
if self.hs.config.limit_usage_by_mau is True:
|
||||||
# If the user is already part of the MAU cohort
|
# If the user is already part of the MAU cohort
|
||||||
@ -796,5 +799,7 @@ class Auth(object):
|
|||||||
current_mau = yield self.store.get_monthly_active_count()
|
current_mau = yield self.store.get_monthly_active_count()
|
||||||
if current_mau >= self.hs.config.max_mau_value:
|
if current_mau >= self.hs.config.max_mau_value:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
403, "Monthly Active User Limits AU Limit Exceeded",
|
||||||
|
admin_uri=self.hs.config.admin_uri,
|
||||||
|
errcode=Codes.RESOURCE_LIMIT_EXCEED
|
||||||
)
|
)
|
||||||
|
@ -56,8 +56,7 @@ class Codes(object):
|
|||||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||||
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
|
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
|
||||||
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
|
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
|
||||||
MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
|
RESOURCE_LIMIT_EXCEED = "M_RESOURCE_LIMIT_EXCEED"
|
||||||
HS_DISABLED = "M_HS_DISABLED"
|
|
||||||
UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION"
|
UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION"
|
||||||
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
|
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
|
||||||
|
|
||||||
@ -225,11 +224,16 @@ class NotFoundError(SynapseError):
|
|||||||
|
|
||||||
class AuthError(SynapseError):
|
class AuthError(SynapseError):
|
||||||
"""An error raised when there was a problem authorising an event."""
|
"""An error raised when there was a problem authorising an event."""
|
||||||
|
def __init__(self, code, msg, errcode=Codes.FORBIDDEN, admin_uri=None):
|
||||||
|
self.admin_uri = admin_uri
|
||||||
|
super(AuthError, self).__init__(code, msg, errcode=errcode)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def error_dict(self):
|
||||||
if "errcode" not in kwargs:
|
return cs_error(
|
||||||
kwargs["errcode"] = Codes.FORBIDDEN
|
self.msg,
|
||||||
super(AuthError, self).__init__(*args, **kwargs)
|
self.errcode,
|
||||||
|
admin_uri=self.admin_uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EventSizeError(SynapseError):
|
class EventSizeError(SynapseError):
|
||||||
|
@ -39,7 +39,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
|
|||||||
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
from synapse.replication.slave.storage.room import RoomStore
|
from synapse.replication.slave.storage.room import RoomStore
|
||||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.rest.client.v1.room import (
|
from synapse.rest.client.v1.room import (
|
||||||
JoinedRoomMemberListRestServlet,
|
JoinedRoomMemberListRestServlet,
|
||||||
@ -66,7 +66,7 @@ class ClientReaderSlavedStore(
|
|||||||
DirectoryStore,
|
DirectoryStore,
|
||||||
SlavedApplicationServiceStore,
|
SlavedApplicationServiceStore,
|
||||||
SlavedRegistrationStore,
|
SlavedRegistrationStore,
|
||||||
TransactionStore,
|
SlavedTransactionStore,
|
||||||
SlavedClientIpStore,
|
SlavedClientIpStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
):
|
):
|
||||||
@ -168,11 +168,13 @@ def start(config_options):
|
|||||||
database_engine = create_engine(config.database_config)
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||||
|
|
||||||
ss = ClientReaderServer(
|
ss = ClientReaderServer(
|
||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
tls_server_context_factory=tls_server_context_factory,
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
tls_client_options_factory=tls_client_options_factory,
|
||||||
config=config,
|
config=config,
|
||||||
version_string="Synapse/" + get_version_string(synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
|
@ -43,7 +43,7 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore
|
|||||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
from synapse.replication.slave.storage.room import RoomStore
|
from synapse.replication.slave.storage.room import RoomStore
|
||||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.rest.client.v1.room import (
|
from synapse.rest.client.v1.room import (
|
||||||
JoinRoomAliasServlet,
|
JoinRoomAliasServlet,
|
||||||
@ -63,7 +63,7 @@ logger = logging.getLogger("synapse.app.event_creator")
|
|||||||
|
|
||||||
class EventCreatorSlavedStore(
|
class EventCreatorSlavedStore(
|
||||||
DirectoryStore,
|
DirectoryStore,
|
||||||
TransactionStore,
|
SlavedTransactionStore,
|
||||||
SlavedProfileStore,
|
SlavedProfileStore,
|
||||||
SlavedAccountDataStore,
|
SlavedAccountDataStore,
|
||||||
SlavedPusherStore,
|
SlavedPusherStore,
|
||||||
@ -174,11 +174,13 @@ def start(config_options):
|
|||||||
database_engine = create_engine(config.database_config)
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||||
|
|
||||||
ss = EventCreatorServer(
|
ss = EventCreatorServer(
|
||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
tls_server_context_factory=tls_server_context_factory,
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
tls_client_options_factory=tls_client_options_factory,
|
||||||
config=config,
|
config=config,
|
||||||
version_string="Synapse/" + get_version_string(synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
|
@ -32,11 +32,16 @@ from synapse.http.site import SynapseSite
|
|||||||
from synapse.metrics import RegistryProxy
|
from synapse.metrics import RegistryProxy
|
||||||
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
|
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
|
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||||
from synapse.replication.slave.storage.directory import DirectoryStore
|
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
||||||
|
from synapse.replication.slave.storage.profile import SlavedProfileStore
|
||||||
|
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
||||||
|
from synapse.replication.slave.storage.pushers import SlavedPusherStore
|
||||||
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
from synapse.replication.slave.storage.room import RoomStore
|
from synapse.replication.slave.storage.room import RoomStore
|
||||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
@ -49,11 +54,16 @@ logger = logging.getLogger("synapse.app.federation_reader")
|
|||||||
|
|
||||||
|
|
||||||
class FederationReaderSlavedStore(
|
class FederationReaderSlavedStore(
|
||||||
|
SlavedProfileStore,
|
||||||
|
SlavedApplicationServiceStore,
|
||||||
|
SlavedPusherStore,
|
||||||
|
SlavedPushRuleStore,
|
||||||
|
SlavedReceiptsStore,
|
||||||
SlavedEventStore,
|
SlavedEventStore,
|
||||||
SlavedKeyStore,
|
SlavedKeyStore,
|
||||||
RoomStore,
|
RoomStore,
|
||||||
DirectoryStore,
|
DirectoryStore,
|
||||||
TransactionStore,
|
SlavedTransactionStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
@ -143,11 +153,13 @@ def start(config_options):
|
|||||||
database_engine = create_engine(config.database_config)
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||||
|
|
||||||
ss = FederationReaderServer(
|
ss = FederationReaderServer(
|
||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
tls_server_context_factory=tls_server_context_factory,
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
tls_client_options_factory=tls_client_options_factory,
|
||||||
config=config,
|
config=config,
|
||||||
version_string="Synapse/" + get_version_string(synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
|
@ -36,11 +36,11 @@ from synapse.replication.slave.storage.events import SlavedEventStore
|
|||||||
from synapse.replication.slave.storage.presence import SlavedPresenceStore
|
from synapse.replication.slave.storage.presence import SlavedPresenceStore
|
||||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.logcontext import LoggingContext, run_in_background
|
from synapse.util.logcontext import LoggingContext, run_in_background
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
@ -50,7 +50,7 @@ logger = logging.getLogger("synapse.app.federation_sender")
|
|||||||
|
|
||||||
|
|
||||||
class FederationSenderSlaveStore(
|
class FederationSenderSlaveStore(
|
||||||
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
|
SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore,
|
||||||
SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
|
SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
|
||||||
):
|
):
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
@ -186,11 +186,13 @@ def start(config_options):
|
|||||||
config.send_federation = True
|
config.send_federation = True
|
||||||
|
|
||||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||||
|
|
||||||
ps = FederationSenderServer(
|
ps = FederationSenderServer(
|
||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
tls_server_context_factory=tls_server_context_factory,
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
tls_client_options_factory=tls_client_options_factory,
|
||||||
config=config,
|
config=config,
|
||||||
version_string="Synapse/" + get_version_string(synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
|
@ -208,11 +208,13 @@ def start(config_options):
|
|||||||
database_engine = create_engine(config.database_config)
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||||
|
|
||||||
ss = FrontendProxyServer(
|
ss = FrontendProxyServer(
|
||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
tls_server_context_factory=tls_server_context_factory,
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
tls_client_options_factory=tls_client_options_factory,
|
||||||
config=config,
|
config=config,
|
||||||
version_string="Synapse/" + get_version_string(synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
|
@ -303,8 +303,8 @@ class SynapseHomeServer(HomeServer):
|
|||||||
|
|
||||||
|
|
||||||
# Gauges to expose monthly active user control metrics
|
# Gauges to expose monthly active user control metrics
|
||||||
current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
|
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
|
||||||
max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
|
max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
|
||||||
|
|
||||||
|
|
||||||
def setup(config_options):
|
def setup(config_options):
|
||||||
@ -338,6 +338,7 @@ def setup(config_options):
|
|||||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||||
|
|
||||||
database_engine = create_engine(config.database_config)
|
database_engine = create_engine(config.database_config)
|
||||||
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
|
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
|
||||||
@ -346,6 +347,7 @@ def setup(config_options):
|
|||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
tls_server_context_factory=tls_server_context_factory,
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
tls_client_options_factory=tls_client_options_factory,
|
||||||
config=config,
|
config=config,
|
||||||
version_string="Synapse/" + get_version_string(synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
@ -530,7 +532,7 @@ def run(hs):
|
|||||||
if hs.config.limit_usage_by_mau:
|
if hs.config.limit_usage_by_mau:
|
||||||
count = yield hs.get_datastore().get_monthly_active_count()
|
count = yield hs.get_datastore().get_monthly_active_count()
|
||||||
current_mau_gauge.set(float(count))
|
current_mau_gauge.set(float(count))
|
||||||
max_mau_value_gauge.set(float(hs.config.max_mau_value))
|
max_mau_gauge.set(float(hs.config.max_mau_value))
|
||||||
|
|
||||||
hs.get_datastore().initialise_reserved_users(
|
hs.get_datastore().initialise_reserved_users(
|
||||||
hs.config.mau_limits_reserved_threepids
|
hs.config.mau_limits_reserved_threepids
|
||||||
|
@ -34,7 +34,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
|
|||||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||||
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -52,7 +52,7 @@ class MediaRepositorySlavedStore(
|
|||||||
SlavedApplicationServiceStore,
|
SlavedApplicationServiceStore,
|
||||||
SlavedRegistrationStore,
|
SlavedRegistrationStore,
|
||||||
SlavedClientIpStore,
|
SlavedClientIpStore,
|
||||||
TransactionStore,
|
SlavedTransactionStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
MediaRepositoryStore,
|
MediaRepositoryStore,
|
||||||
):
|
):
|
||||||
@ -155,11 +155,13 @@ def start(config_options):
|
|||||||
database_engine = create_engine(config.database_config)
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||||
|
|
||||||
ss = MediaRepositoryServer(
|
ss = MediaRepositoryServer(
|
||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
tls_server_context_factory=tls_server_context_factory,
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
tls_client_options_factory=tls_client_options_factory,
|
||||||
config=config,
|
config=config,
|
||||||
version_string="Synapse/" + get_version_string(synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
|
@ -214,11 +214,13 @@ def start(config_options):
|
|||||||
config.update_user_directory = True
|
config.update_user_directory = True
|
||||||
|
|
||||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||||
|
|
||||||
ps = UserDirectoryServer(
|
ps = UserDirectoryServer(
|
||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
tls_server_context_factory=tls_server_context_factory,
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
tls_client_options_factory=tls_client_options_factory,
|
||||||
config=config,
|
config=config,
|
||||||
version_string="Synapse/" + get_version_string(synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
|
@ -193,9 +193,8 @@ def setup_logging(config, use_worker_options=False):
|
|||||||
|
|
||||||
def sighup(signum, stack):
|
def sighup(signum, stack):
|
||||||
# it might be better to use a file watcher or something for this.
|
# it might be better to use a file watcher or something for this.
|
||||||
logging.info("Reloading log config from %s due to SIGHUP",
|
|
||||||
log_config)
|
|
||||||
load_log_config()
|
load_log_config()
|
||||||
|
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
|
||||||
|
|
||||||
load_log_config()
|
load_log_config()
|
||||||
|
|
||||||
|
@ -82,6 +82,10 @@ class ServerConfig(Config):
|
|||||||
self.hs_disabled = config.get("hs_disabled", False)
|
self.hs_disabled = config.get("hs_disabled", False)
|
||||||
self.hs_disabled_message = config.get("hs_disabled_message", "")
|
self.hs_disabled_message = config.get("hs_disabled_message", "")
|
||||||
|
|
||||||
|
# Admin uri to direct users at should their instance become blocked
|
||||||
|
# due to resource constraints
|
||||||
|
self.admin_uri = config.get("admin_uri", None)
|
||||||
|
|
||||||
# FIXME: federation_domain_whitelist needs sytests
|
# FIXME: federation_domain_whitelist needs sytests
|
||||||
self.federation_domain_whitelist = None
|
self.federation_domain_whitelist = None
|
||||||
federation_domain_whitelist = config.get(
|
federation_domain_whitelist = config.get(
|
||||||
|
@ -11,19 +11,22 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
from OpenSSL import SSL, crypto
|
from OpenSSL import SSL, crypto
|
||||||
from twisted.internet import ssl
|
|
||||||
from twisted.internet._sslverify import _defaultCurveName
|
from twisted.internet._sslverify import _defaultCurveName
|
||||||
|
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
|
||||||
|
from twisted.internet.ssl import CertificateOptions, ContextFactory
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ServerContextFactory(ssl.ContextFactory):
|
class ServerContextFactory(ContextFactory):
|
||||||
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
|
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
|
||||||
connections and to make connections to remote servers."""
|
connections."""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
self._context = SSL.Context(SSL.SSLv23_METHOD)
|
self._context = SSL.Context(SSL.SSLv23_METHOD)
|
||||||
@ -48,3 +51,78 @@ class ServerContextFactory(ssl.ContextFactory):
|
|||||||
|
|
||||||
def getContext(self):
|
def getContext(self):
|
||||||
return self._context
|
return self._context
|
||||||
|
|
||||||
|
|
||||||
|
def _idnaBytes(text):
|
||||||
|
"""
|
||||||
|
Convert some text typed by a human into some ASCII bytes. This is a
|
||||||
|
copy of twisted.internet._idna._idnaBytes. For documentation, see the
|
||||||
|
twisted documentation.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import idna
|
||||||
|
except ImportError:
|
||||||
|
return text.encode("idna")
|
||||||
|
else:
|
||||||
|
return idna.encode(text)
|
||||||
|
|
||||||
|
|
||||||
|
def _tolerateErrors(wrapped):
|
||||||
|
"""
|
||||||
|
Wrap up an info_callback for pyOpenSSL so that if something goes wrong
|
||||||
|
the error is immediately logged and the connection is dropped if possible.
|
||||||
|
This is a copy of twisted.internet._sslverify._tolerateErrors. For
|
||||||
|
documentation, see the twisted documentation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def infoCallback(connection, where, ret):
|
||||||
|
try:
|
||||||
|
return wrapped(connection, where, ret)
|
||||||
|
except: # noqa: E722, taken from the twisted implementation
|
||||||
|
f = Failure()
|
||||||
|
logger.exception("Error during info_callback")
|
||||||
|
connection.get_app_data().failVerification(f)
|
||||||
|
|
||||||
|
return infoCallback
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IOpenSSLClientConnectionCreator)
|
||||||
|
class ClientTLSOptions(object):
|
||||||
|
"""
|
||||||
|
Client creator for TLS without certificate identity verification. This is a
|
||||||
|
copy of twisted.internet._sslverify.ClientTLSOptions with the identity
|
||||||
|
verification left out. For documentation, see the twisted documentation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hostname, ctx):
|
||||||
|
self._ctx = ctx
|
||||||
|
self._hostname = hostname
|
||||||
|
self._hostnameBytes = _idnaBytes(hostname)
|
||||||
|
ctx.set_info_callback(
|
||||||
|
_tolerateErrors(self._identityVerifyingInfoCallback)
|
||||||
|
)
|
||||||
|
|
||||||
|
def clientConnectionForTLS(self, tlsProtocol):
|
||||||
|
context = self._ctx
|
||||||
|
connection = SSL.Connection(context, None)
|
||||||
|
connection.set_app_data(tlsProtocol)
|
||||||
|
return connection
|
||||||
|
|
||||||
|
def _identityVerifyingInfoCallback(self, connection, where, ret):
|
||||||
|
if where & SSL.SSL_CB_HANDSHAKE_START:
|
||||||
|
connection.set_tlsext_host_name(self._hostnameBytes)
|
||||||
|
|
||||||
|
|
||||||
|
class ClientTLSOptionsFactory(object):
|
||||||
|
"""Factory for Twisted ClientTLSOptions that are used to make connections
|
||||||
|
to remote servers for federation."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
# We don't use config options yet
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_options(self, host):
|
||||||
|
return ClientTLSOptions(
|
||||||
|
host.decode('utf-8'),
|
||||||
|
CertificateOptions(verify=False).getContext()
|
||||||
|
)
|
||||||
|
@ -30,14 +30,14 @@ KEY_API_V1 = b"/_matrix/key/v1/"
|
|||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
|
def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
|
||||||
"""Fetch the keys for a remote server."""
|
"""Fetch the keys for a remote server."""
|
||||||
|
|
||||||
factory = SynapseKeyClientFactory()
|
factory = SynapseKeyClientFactory()
|
||||||
factory.path = path
|
factory.path = path
|
||||||
factory.host = server_name
|
factory.host = server_name
|
||||||
endpoint = matrix_federation_endpoint(
|
endpoint = matrix_federation_endpoint(
|
||||||
reactor, server_name, ssl_context_factory, timeout=30
|
reactor, server_name, tls_client_options_factory, timeout=30
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
|
@ -512,7 +512,7 @@ class Keyring(object):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
(response, tls_certificate) = yield fetch_server_key(
|
(response, tls_certificate) = yield fetch_server_key(
|
||||||
server_name, self.hs.tls_server_context_factory,
|
server_name, self.hs.tls_client_options_factory,
|
||||||
path=(b"/_matrix/key/v2/server/%s" % (
|
path=(b"/_matrix/key/v2/server/%s" % (
|
||||||
urllib.quote(requested_key_id),
|
urllib.quote(requested_key_id),
|
||||||
)).encode("ascii"),
|
)).encode("ascii"),
|
||||||
@ -655,7 +655,7 @@ class Keyring(object):
|
|||||||
# Try to fetch the key from the remote server.
|
# Try to fetch the key from the remote server.
|
||||||
|
|
||||||
(response, tls_certificate) = yield fetch_server_key(
|
(response, tls_certificate) = yield fetch_server_key(
|
||||||
server_name, self.hs.tls_server_context_factory
|
server_name, self.hs.tls_client_options_factory
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check the response.
|
# Check the response.
|
||||||
|
@ -39,8 +39,12 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
|
|||||||
from synapse.federation.persistence import TransactionActions
|
from synapse.federation.persistence import TransactionActions
|
||||||
from synapse.federation.units import Edu, Transaction
|
from synapse.federation.units import Edu, Transaction
|
||||||
from synapse.http.endpoint import parse_server_name
|
from synapse.http.endpoint import parse_server_name
|
||||||
|
from synapse.replication.http.federation import (
|
||||||
|
ReplicationFederationSendEduRestServlet,
|
||||||
|
ReplicationGetQueryRestServlet,
|
||||||
|
)
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
from synapse.util import async
|
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
|
||||||
@ -67,8 +71,8 @@ class FederationServer(FederationBase):
|
|||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.handler = hs.get_handlers().federation_handler
|
self.handler = hs.get_handlers().federation_handler
|
||||||
|
|
||||||
self._server_linearizer = async.Linearizer("fed_server")
|
self._server_linearizer = Linearizer("fed_server")
|
||||||
self._transaction_linearizer = async.Linearizer("fed_txn_handler")
|
self._transaction_linearizer = Linearizer("fed_txn_handler")
|
||||||
|
|
||||||
self.transaction_actions = TransactionActions(self.store)
|
self.transaction_actions = TransactionActions(self.store)
|
||||||
|
|
||||||
@ -200,7 +204,7 @@ class FederationServer(FederationBase):
|
|||||||
event_id, f.getTraceback().rstrip(),
|
event_id, f.getTraceback().rstrip(),
|
||||||
)
|
)
|
||||||
|
|
||||||
yield async.concurrently_execute(
|
yield concurrently_execute(
|
||||||
process_pdus_for_room, pdus_by_room.keys(),
|
process_pdus_for_room, pdus_by_room.keys(),
|
||||||
TRANSACTION_CONCURRENCY_LIMIT,
|
TRANSACTION_CONCURRENCY_LIMIT,
|
||||||
)
|
)
|
||||||
@ -760,6 +764,8 @@ class FederationHandlerRegistry(object):
|
|||||||
if edu_type in self.edu_handlers:
|
if edu_type in self.edu_handlers:
|
||||||
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
|
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
|
||||||
|
|
||||||
|
logger.info("Registering federation EDU handler for %r", edu_type)
|
||||||
|
|
||||||
self.edu_handlers[edu_type] = handler
|
self.edu_handlers[edu_type] = handler
|
||||||
|
|
||||||
def register_query_handler(self, query_type, handler):
|
def register_query_handler(self, query_type, handler):
|
||||||
@ -778,6 +784,8 @@ class FederationHandlerRegistry(object):
|
|||||||
"Already have a Query handler for %s" % (query_type,)
|
"Already have a Query handler for %s" % (query_type,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info("Registering federation query handler for %r", query_type)
|
||||||
|
|
||||||
self.query_handlers[query_type] = handler
|
self.query_handlers[query_type] = handler
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -800,3 +808,49 @@ class FederationHandlerRegistry(object):
|
|||||||
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
|
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
|
||||||
|
|
||||||
return handler(args)
|
return handler(args)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
|
||||||
|
"""A FederationHandlerRegistry for worker processes.
|
||||||
|
|
||||||
|
When receiving EDU or queries it will check if an appropriate handler has
|
||||||
|
been registered on the worker, if there isn't one then it calls off to the
|
||||||
|
master process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.config = hs.config
|
||||||
|
self.http_client = hs.get_simple_http_client()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
|
||||||
|
self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
|
||||||
|
|
||||||
|
super(ReplicationFederationHandlerRegistry, self).__init__()
|
||||||
|
|
||||||
|
def on_edu(self, edu_type, origin, content):
|
||||||
|
"""Overrides FederationHandlerRegistry
|
||||||
|
"""
|
||||||
|
handler = self.edu_handlers.get(edu_type)
|
||||||
|
if handler:
|
||||||
|
return super(ReplicationFederationHandlerRegistry, self).on_edu(
|
||||||
|
edu_type, origin, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._send_edu(
|
||||||
|
edu_type=edu_type,
|
||||||
|
origin=origin,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_query(self, query_type, args):
|
||||||
|
"""Overrides FederationHandlerRegistry
|
||||||
|
"""
|
||||||
|
handler = self.query_handlers.get(query_type)
|
||||||
|
if handler:
|
||||||
|
return handler(args)
|
||||||
|
|
||||||
|
return self._get_query_client(
|
||||||
|
query_type=query_type,
|
||||||
|
args=args,
|
||||||
|
)
|
||||||
|
@ -828,12 +828,26 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_threepid(self, user_id, medium, address):
|
def delete_threepid(self, user_id, medium, address):
|
||||||
|
"""Attempts to unbind the 3pid on the identity servers and deletes it
|
||||||
|
from the local database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str)
|
||||||
|
medium (str)
|
||||||
|
address (str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[bool]: Returns True if successfully unbound the 3pid on
|
||||||
|
the identity server, False if identity server doesn't support the
|
||||||
|
unbind API.
|
||||||
|
"""
|
||||||
|
|
||||||
# 'Canonicalise' email addresses as per above
|
# 'Canonicalise' email addresses as per above
|
||||||
if medium == 'email':
|
if medium == 'email':
|
||||||
address = address.lower()
|
address = address.lower()
|
||||||
|
|
||||||
identity_handler = self.hs.get_handlers().identity_handler
|
identity_handler = self.hs.get_handlers().identity_handler
|
||||||
yield identity_handler.unbind_threepid(
|
result = yield identity_handler.try_unbind_threepid(
|
||||||
user_id,
|
user_id,
|
||||||
{
|
{
|
||||||
'medium': medium,
|
'medium': medium,
|
||||||
@ -841,10 +855,10 @@ class AuthHandler(BaseHandler):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = yield self.store.user_delete_threepid(
|
yield self.store.user_delete_threepid(
|
||||||
user_id, medium, address,
|
user_id, medium, address,
|
||||||
)
|
)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(result)
|
||||||
|
|
||||||
def _save_session(self, session):
|
def _save_session(self, session):
|
||||||
# TODO: Persistent storage
|
# TODO: Persistent storage
|
||||||
|
@ -51,7 +51,8 @@ class DeactivateAccountHandler(BaseHandler):
|
|||||||
erase_data (bool): whether to GDPR-erase the user's data
|
erase_data (bool): whether to GDPR-erase the user's data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred
|
Deferred[bool]: True if identity server supports removing
|
||||||
|
threepids, otherwise False.
|
||||||
"""
|
"""
|
||||||
# FIXME: Theoretically there is a race here wherein user resets
|
# FIXME: Theoretically there is a race here wherein user resets
|
||||||
# password using threepid.
|
# password using threepid.
|
||||||
@ -60,16 +61,22 @@ class DeactivateAccountHandler(BaseHandler):
|
|||||||
# leave the user still active so they can try again.
|
# leave the user still active so they can try again.
|
||||||
# Ideally we would prevent password resets and then do this in the
|
# Ideally we would prevent password resets and then do this in the
|
||||||
# background thread.
|
# background thread.
|
||||||
|
|
||||||
|
# This will be set to false if the identity server doesn't support
|
||||||
|
# unbinding
|
||||||
|
identity_server_supports_unbinding = True
|
||||||
|
|
||||||
threepids = yield self.store.user_get_threepids(user_id)
|
threepids = yield self.store.user_get_threepids(user_id)
|
||||||
for threepid in threepids:
|
for threepid in threepids:
|
||||||
try:
|
try:
|
||||||
yield self._identity_handler.unbind_threepid(
|
result = yield self._identity_handler.try_unbind_threepid(
|
||||||
user_id,
|
user_id,
|
||||||
{
|
{
|
||||||
'medium': threepid['medium'],
|
'medium': threepid['medium'],
|
||||||
'address': threepid['address'],
|
'address': threepid['address'],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
identity_server_supports_unbinding &= result
|
||||||
except Exception:
|
except Exception:
|
||||||
# Do we want this to be a fatal error or should we carry on?
|
# Do we want this to be a fatal error or should we carry on?
|
||||||
logger.exception("Failed to remove threepid from ID server")
|
logger.exception("Failed to remove threepid from ID server")
|
||||||
@ -103,6 +110,8 @@ class DeactivateAccountHandler(BaseHandler):
|
|||||||
# parts users from rooms (if it isn't already running)
|
# parts users from rooms (if it isn't already running)
|
||||||
self._start_user_parting()
|
self._start_user_parting()
|
||||||
|
|
||||||
|
defer.returnValue(identity_server_supports_unbinding)
|
||||||
|
|
||||||
def _start_user_parting(self):
|
def _start_user_parting(self):
|
||||||
"""
|
"""
|
||||||
Start the process that goes through the table of users
|
Start the process that goes through the table of users
|
||||||
|
@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes
|
|||||||
from synapse.api.errors import FederationDeniedError
|
from synapse.api.errors import FederationDeniedError
|
||||||
from synapse.types import RoomStreamToken, get_domain_from_id
|
from synapse.types import RoomStreamToken, get_domain_from_id
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
@ -49,10 +49,15 @@ from synapse.crypto.event_signing import (
|
|||||||
compute_event_signature,
|
compute_event_signature,
|
||||||
)
|
)
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
|
from synapse.replication.http.federation import (
|
||||||
|
ReplicationCleanRoomRestServlet,
|
||||||
|
ReplicationFederationSendEventsRestServlet,
|
||||||
|
)
|
||||||
|
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
|
||||||
from synapse.state import resolve_events_with_factory
|
from synapse.state import resolve_events_with_factory
|
||||||
from synapse.types import UserID, get_domain_from_id
|
from synapse.types import UserID, get_domain_from_id
|
||||||
from synapse.util import logcontext, unwrapFirstError
|
from synapse.util import logcontext, unwrapFirstError
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.distributor import user_joined_room
|
from synapse.util.distributor import user_joined_room
|
||||||
from synapse.util.frozenutils import unfreeze
|
from synapse.util.frozenutils import unfreeze
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
@ -91,6 +96,18 @@ class FederationHandler(BaseHandler):
|
|||||||
self.spam_checker = hs.get_spam_checker()
|
self.spam_checker = hs.get_spam_checker()
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||||
|
self.config = hs.config
|
||||||
|
self.http_client = hs.get_simple_http_client()
|
||||||
|
|
||||||
|
self._send_events_to_master = (
|
||||||
|
ReplicationFederationSendEventsRestServlet.make_client(hs)
|
||||||
|
)
|
||||||
|
self._notify_user_membership_change = (
|
||||||
|
ReplicationUserJoinedLeftRoomRestServlet.make_client(hs)
|
||||||
|
)
|
||||||
|
self._clean_room_for_join_client = (
|
||||||
|
ReplicationCleanRoomRestServlet.make_client(hs)
|
||||||
|
)
|
||||||
|
|
||||||
# When joining a room we need to queue any events for that room up
|
# When joining a room we need to queue any events for that room up
|
||||||
self.room_queues = {}
|
self.room_queues = {}
|
||||||
@ -1158,7 +1175,7 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
context = yield self.state_handler.compute_event_context(event)
|
context = yield self.state_handler.compute_event_context(event)
|
||||||
yield self._persist_events([(event, context)])
|
yield self.persist_events_and_notify([(event, context)])
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
@ -1189,7 +1206,7 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
context = yield self.state_handler.compute_event_context(event)
|
context = yield self.state_handler.compute_event_context(event)
|
||||||
yield self._persist_events([(event, context)])
|
yield self.persist_events_and_notify([(event, context)])
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
@ -1432,7 +1449,7 @@ class FederationHandler(BaseHandler):
|
|||||||
event, context
|
event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._persist_events(
|
yield self.persist_events_and_notify(
|
||||||
[(event, context)],
|
[(event, context)],
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
@ -1470,7 +1487,7 @@ class FederationHandler(BaseHandler):
|
|||||||
], consumeErrors=True,
|
], consumeErrors=True,
|
||||||
))
|
))
|
||||||
|
|
||||||
yield self._persist_events(
|
yield self.persist_events_and_notify(
|
||||||
[
|
[
|
||||||
(ev_info["event"], context)
|
(ev_info["event"], context)
|
||||||
for ev_info, context in zip(event_infos, contexts)
|
for ev_info, context in zip(event_infos, contexts)
|
||||||
@ -1558,7 +1575,7 @@ class FederationHandler(BaseHandler):
|
|||||||
raise
|
raise
|
||||||
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
|
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
|
||||||
|
|
||||||
yield self._persist_events(
|
yield self.persist_events_and_notify(
|
||||||
[
|
[
|
||||||
(e, events_to_context[e.event_id])
|
(e, events_to_context[e.event_id])
|
||||||
for e in itertools.chain(auth_events, state)
|
for e in itertools.chain(auth_events, state)
|
||||||
@ -1569,7 +1586,7 @@ class FederationHandler(BaseHandler):
|
|||||||
event, old_state=state
|
event, old_state=state
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._persist_events(
|
yield self.persist_events_and_notify(
|
||||||
[(event, new_event_context)],
|
[(event, new_event_context)],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2297,7 +2314,7 @@ class FederationHandler(BaseHandler):
|
|||||||
for revocation.
|
for revocation.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = yield self.hs.get_simple_http_client().get_json(
|
response = yield self.http_client.get_json(
|
||||||
url,
|
url,
|
||||||
{"public_key": public_key}
|
{"public_key": public_key}
|
||||||
)
|
)
|
||||||
@ -2310,7 +2327,7 @@ class FederationHandler(BaseHandler):
|
|||||||
raise AuthError(403, "Third party certificate was invalid")
|
raise AuthError(403, "Third party certificate was invalid")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _persist_events(self, event_and_contexts, backfilled=False):
|
def persist_events_and_notify(self, event_and_contexts, backfilled=False):
|
||||||
"""Persists events and tells the notifier/pushers about them, if
|
"""Persists events and tells the notifier/pushers about them, if
|
||||||
necessary.
|
necessary.
|
||||||
|
|
||||||
@ -2322,14 +2339,21 @@ class FederationHandler(BaseHandler):
|
|||||||
Returns:
|
Returns:
|
||||||
Deferred
|
Deferred
|
||||||
"""
|
"""
|
||||||
max_stream_id = yield self.store.persist_events(
|
if self.config.worker_app:
|
||||||
event_and_contexts,
|
yield self._send_events_to_master(
|
||||||
backfilled=backfilled,
|
store=self.store,
|
||||||
)
|
event_and_contexts=event_and_contexts,
|
||||||
|
backfilled=backfilled
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
max_stream_id = yield self.store.persist_events(
|
||||||
|
event_and_contexts,
|
||||||
|
backfilled=backfilled,
|
||||||
|
)
|
||||||
|
|
||||||
if not backfilled: # Never notify for backfilled events
|
if not backfilled: # Never notify for backfilled events
|
||||||
for event, _ in event_and_contexts:
|
for event, _ in event_and_contexts:
|
||||||
self._notify_persisted_event(event, max_stream_id)
|
self._notify_persisted_event(event, max_stream_id)
|
||||||
|
|
||||||
def _notify_persisted_event(self, event, max_stream_id):
|
def _notify_persisted_event(self, event, max_stream_id):
|
||||||
"""Checks to see if notifier/pushers should be notified about the
|
"""Checks to see if notifier/pushers should be notified about the
|
||||||
@ -2368,9 +2392,25 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _clean_room_for_join(self, room_id):
|
def _clean_room_for_join(self, room_id):
|
||||||
return self.store.clean_room_for_join(room_id)
|
"""Called to clean up any data in DB for a given room, ready for the
|
||||||
|
server to join the room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str)
|
||||||
|
"""
|
||||||
|
if self.config.worker_app:
|
||||||
|
return self._clean_room_for_join_client(room_id)
|
||||||
|
else:
|
||||||
|
return self.store.clean_room_for_join(room_id)
|
||||||
|
|
||||||
def user_joined_room(self, user, room_id):
|
def user_joined_room(self, user, room_id):
|
||||||
"""Called when a new user has joined the room
|
"""Called when a new user has joined the room
|
||||||
"""
|
"""
|
||||||
return user_joined_room(self.distributor, user, room_id)
|
if self.config.worker_app:
|
||||||
|
return self._notify_user_membership_change(
|
||||||
|
room_id=room_id,
|
||||||
|
user_id=user.to_string(),
|
||||||
|
change="joined",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return user_joined_room(self.distributor, user, room_id)
|
||||||
|
@ -137,15 +137,19 @@ class IdentityHandler(BaseHandler):
|
|||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def unbind_threepid(self, mxid, threepid):
|
def try_unbind_threepid(self, mxid, threepid):
|
||||||
"""
|
"""Removes a binding from an identity server
|
||||||
Removes a binding from an identity server
|
|
||||||
Args:
|
Args:
|
||||||
mxid (str): Matrix user ID of binding to be removed
|
mxid (str): Matrix user ID of binding to be removed
|
||||||
threepid (dict): Dict with medium & address of binding to be removed
|
threepid (dict): Dict with medium & address of binding to be removed
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SynapseError: If we failed to contact the identity server
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[bool]: True on success, otherwise False
|
Deferred[bool]: True on success, otherwise False if the identity
|
||||||
|
server doesn't support unbinding
|
||||||
"""
|
"""
|
||||||
logger.debug("unbinding threepid %r from %s", threepid, mxid)
|
logger.debug("unbinding threepid %r from %s", threepid, mxid)
|
||||||
if not self.trusted_id_servers:
|
if not self.trusted_id_servers:
|
||||||
@ -175,11 +179,21 @@ class IdentityHandler(BaseHandler):
|
|||||||
content=content,
|
content=content,
|
||||||
destination_is=id_server,
|
destination_is=id_server,
|
||||||
)
|
)
|
||||||
yield self.http_client.post_json_get_json(
|
try:
|
||||||
url,
|
yield self.http_client.post_json_get_json(
|
||||||
content,
|
url,
|
||||||
headers,
|
content,
|
||||||
)
|
headers,
|
||||||
|
)
|
||||||
|
except HttpResponseException as e:
|
||||||
|
if e.code in (400, 404, 501,):
|
||||||
|
# The remote server probably doesn't support unbinding (yet)
|
||||||
|
logger.warn("Received %d response while unbinding threepid", e.code)
|
||||||
|
defer.returnValue(False)
|
||||||
|
else:
|
||||||
|
logger.error("Failed to unbind threepid on identity server: %s", e)
|
||||||
|
raise SynapseError(502, "Failed to contact identity server")
|
||||||
|
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
|
|||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.types import StreamToken, UserID
|
from synapse.types import StreamToken, UserID
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async_helpers import concurrently_execute
|
||||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
@ -32,7 +32,7 @@ from synapse.events.utils import serialize_event
|
|||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||||
from synapse.types import RoomAlias, UserID
|
from synapse.types import RoomAlias, UserID
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.frozenutils import frozendict_json_encoder
|
from synapse.util.frozenutils import frozendict_json_encoder
|
||||||
from synapse.util.logcontext import run_in_background
|
from synapse.util.logcontext import run_in_background
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
|
@ -22,7 +22,7 @@ from synapse.api.constants import Membership
|
|||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
from synapse.util.async import ReadWriteLock
|
from synapse.util.async_helpers import ReadWriteLock
|
||||||
from synapse.util.logcontext import run_in_background
|
from synapse.util.logcontext import run_in_background
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
@ -36,7 +36,7 @@ from synapse.api.errors import SynapseError
|
|||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
from synapse.storage.presence import UserPresenceState
|
from synapse.storage.presence import UserPresenceState
|
||||||
from synapse.types import UserID, get_domain_from_id
|
from synapse.types import UserID, get_domain_from_id
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
from synapse.util.logcontext import run_in_background
|
from synapse.util.logcontext import run_in_background
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
@ -95,6 +95,7 @@ class PresenceHandler(object):
|
|||||||
Args:
|
Args:
|
||||||
hs (synapse.server.HomeServer):
|
hs (synapse.server.HomeServer):
|
||||||
"""
|
"""
|
||||||
|
self.hs = hs
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
@ -230,6 +231,10 @@ class PresenceHandler(object):
|
|||||||
earlier than they should when synapse is restarted. This affect of this
|
earlier than they should when synapse is restarted. This affect of this
|
||||||
is some spurious presence changes that will self-correct.
|
is some spurious presence changes that will self-correct.
|
||||||
"""
|
"""
|
||||||
|
# If the DB pool has already terminated, don't try updating
|
||||||
|
if not self.hs.get_db_pool().running:
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Performing _on_shutdown. Persisting %d unpersisted changes",
|
"Performing _on_shutdown. Persisting %d unpersisted changes",
|
||||||
len(self.user_to_current_state)
|
len(self.user_to_current_state)
|
||||||
|
@ -17,7 +17,7 @@ import logging
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ from synapse.api.errors import (
|
|||||||
)
|
)
|
||||||
from synapse.http.client import CaptchaServerHttpClient
|
from synapse.http.client import CaptchaServerHttpClient
|
||||||
from synapse.types import RoomAlias, RoomID, UserID, create_requester
|
from synapse.types import RoomAlias, RoomID, UserID, create_requester
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.threepids import check_3pid_allowed
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
Raises:
|
Raises:
|
||||||
RegistrationError if there was a problem registering.
|
RegistrationError if there was a problem registering.
|
||||||
"""
|
"""
|
||||||
yield self._check_mau_limits()
|
|
||||||
|
yield self.auth.check_auth_blocking()
|
||||||
password_hash = None
|
password_hash = None
|
||||||
if password:
|
if password:
|
||||||
password_hash = yield self.auth_handler().hash(password)
|
password_hash = yield self.auth_handler().hash(password)
|
||||||
@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
400,
|
400,
|
||||||
"User ID can only contain characters a-z, 0-9, or '=_-./'",
|
"User ID can only contain characters a-z, 0-9, or '=_-./'",
|
||||||
)
|
)
|
||||||
yield self._check_mau_limits()
|
yield self.auth.check_auth_blocking()
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
if localpart is None:
|
if localpart is None:
|
||||||
raise SynapseError(400, "Request must include user id")
|
raise SynapseError(400, "Request must include user id")
|
||||||
yield self._check_mau_limits()
|
yield self.auth.check_auth_blocking()
|
||||||
need_register = True
|
need_register = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -533,14 +534,3 @@ class RegistrationHandler(BaseHandler):
|
|||||||
remote_room_hosts=remote_room_hosts,
|
remote_room_hosts=remote_room_hosts,
|
||||||
action="join",
|
action="join",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _check_mau_limits(self):
|
|
||||||
"""
|
|
||||||
Do not accept registrations if monthly active user limits exceeded
|
|
||||||
and limiting is enabled
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
yield self.auth.check_auth_blocking()
|
|
||||||
except AuthError as e:
|
|
||||||
raise RegistrationError(e.code, str(e), e.errcode)
|
|
||||||
|
@ -26,7 +26,7 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from synapse.api.constants import EventTypes, JoinRules
|
from synapse.api.constants import EventTypes, JoinRules
|
||||||
from synapse.types import ThirdPartyInstanceID
|
from synapse.types import ThirdPartyInstanceID
|
||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async_helpers import concurrently_execute
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ import synapse.types
|
|||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
from synapse.api.errors import AuthError, Codes, SynapseError
|
||||||
from synapse.types import RoomID, UserID
|
from synapse.types import RoomID, UserID
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.distributor import user_joined_room, user_left_room
|
from synapse.util.distributor import user_joined_room, user_left_room
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -25,7 +25,7 @@ from twisted.internet import defer
|
|||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.push.clientformat import format_push_rules_for_user
|
from synapse.push.clientformat import format_push_rules_for_user
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async_helpers import concurrently_execute
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
@ -42,7 +42,7 @@ from twisted.web.http_headers import Headers
|
|||||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
from synapse.http import cancelled_to_request_timed_out_error, redact_uri
|
from synapse.http import cancelled_to_request_timed_out_error, redact_uri
|
||||||
from synapse.http.endpoint import SpiderEndpoint
|
from synapse.http.endpoint import SpiderEndpoint
|
||||||
from synapse.util.async import add_timeout_to_deferred
|
from synapse.util.async_helpers import add_timeout_to_deferred
|
||||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||||
from synapse.util.logcontext import make_deferred_yieldable
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
|
|
||||||
|
@ -26,7 +26,6 @@ from twisted.names.error import DNSNameError, DomainError
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
SERVER_CACHE = {}
|
SERVER_CACHE = {}
|
||||||
|
|
||||||
# our record of an individual server which can be tried to reach a destination.
|
# our record of an individual server which can be tried to reach a destination.
|
||||||
@ -103,15 +102,16 @@ def parse_and_validate_server_name(server_name):
|
|||||||
return host, port
|
return host, port
|
||||||
|
|
||||||
|
|
||||||
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
|
def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
|
||||||
timeout=None):
|
timeout=None):
|
||||||
"""Construct an endpoint for the given matrix destination.
|
"""Construct an endpoint for the given matrix destination.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
reactor: Twisted reactor.
|
reactor: Twisted reactor.
|
||||||
destination (bytes): The name of the server to connect to.
|
destination (bytes): The name of the server to connect to.
|
||||||
ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory
|
tls_client_options_factory
|
||||||
which generates SSL contexts to use for TLS.
|
(synapse.crypto.context_factory.ClientTLSOptionsFactory):
|
||||||
|
Factory which generates TLS options for client connections.
|
||||||
timeout (int): connection timeout in seconds
|
timeout (int): connection timeout in seconds
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -122,13 +122,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
|
|||||||
if timeout is not None:
|
if timeout is not None:
|
||||||
endpoint_kw_args.update(timeout=timeout)
|
endpoint_kw_args.update(timeout=timeout)
|
||||||
|
|
||||||
if ssl_context_factory is None:
|
if tls_client_options_factory is None:
|
||||||
transport_endpoint = HostnameEndpoint
|
transport_endpoint = HostnameEndpoint
|
||||||
default_port = 8008
|
default_port = 8008
|
||||||
else:
|
else:
|
||||||
def transport_endpoint(reactor, host, port, timeout):
|
def transport_endpoint(reactor, host, port, timeout):
|
||||||
return wrapClientTLS(
|
return wrapClientTLS(
|
||||||
ssl_context_factory,
|
tls_client_options_factory.get_options(host),
|
||||||
HostnameEndpoint(reactor, host, port, timeout=timeout))
|
HostnameEndpoint(reactor, host, port, timeout=timeout))
|
||||||
default_port = 8448
|
default_port = 8448
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ from synapse.api.errors import (
|
|||||||
from synapse.http import cancelled_to_request_timed_out_error
|
from synapse.http import cancelled_to_request_timed_out_error
|
||||||
from synapse.http.endpoint import matrix_federation_endpoint
|
from synapse.http.endpoint import matrix_federation_endpoint
|
||||||
from synapse.util import logcontext
|
from synapse.util import logcontext
|
||||||
from synapse.util.async import add_timeout_to_deferred
|
from synapse.util.async_helpers import add_timeout_to_deferred
|
||||||
from synapse.util.logcontext import make_deferred_yieldable
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -61,14 +61,14 @@ MAX_SHORT_RETRIES = 3
|
|||||||
|
|
||||||
class MatrixFederationEndpointFactory(object):
|
class MatrixFederationEndpointFactory(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.tls_server_context_factory = hs.tls_server_context_factory
|
self.tls_client_options_factory = hs.tls_client_options_factory
|
||||||
|
|
||||||
def endpointForURI(self, uri):
|
def endpointForURI(self, uri):
|
||||||
destination = uri.netloc
|
destination = uri.netloc
|
||||||
|
|
||||||
return matrix_federation_endpoint(
|
return matrix_federation_endpoint(
|
||||||
reactor, destination, timeout=10,
|
reactor, destination, timeout=10,
|
||||||
ssl_context_factory=self.tls_server_context_factory
|
tls_client_options_factory=self.tls_client_options_factory
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ from synapse.api.errors import AuthError
|
|||||||
from synapse.handlers.presence import format_user_presence_state
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
from synapse.types import StreamToken
|
from synapse.types import StreamToken
|
||||||
from synapse.util.async import (
|
from synapse.util.async_helpers import (
|
||||||
DeferredTimeoutError,
|
DeferredTimeoutError,
|
||||||
ObservableDeferred,
|
ObservableDeferred,
|
||||||
add_timeout_to_deferred,
|
add_timeout_to_deferred,
|
||||||
|
@ -26,7 +26,7 @@ from twisted.internet import defer
|
|||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.event_auth import get_user_power_level
|
from synapse.event_auth import get_user_power_level
|
||||||
from synapse.state import POWER_KEY
|
from synapse.state import POWER_KEY
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches import register_cache
|
from synapse.util.caches import register_cache
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ from synapse.push.presentable_names import (
|
|||||||
name_from_member_event,
|
name_from_member_event,
|
||||||
)
|
)
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async_helpers import concurrently_execute
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
from synapse.replication.http import membership, send_event
|
from synapse.replication.http import federation, membership, send_event
|
||||||
|
|
||||||
REPLICATION_PREFIX = "/_synapse/replication"
|
REPLICATION_PREFIX = "/_synapse/replication"
|
||||||
|
|
||||||
@ -27,3 +27,4 @@ class ReplicationRestResource(JsonResource):
|
|||||||
def register_servlets(self, hs):
|
def register_servlets(self, hs):
|
||||||
send_event.register_servlets(hs, self)
|
send_event.register_servlets(hs, self)
|
||||||
membership.register_servlets(hs, self)
|
membership.register_servlets(hs, self)
|
||||||
|
federation.register_servlets(hs, self)
|
||||||
|
259
synapse/replication/http/federation.py
Normal file
259
synapse/replication/http/federation.py
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.events import FrozenEvent
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||||
|
"""Handles events newly received from federation, including persisting and
|
||||||
|
notifying.
|
||||||
|
|
||||||
|
The API looks like:
|
||||||
|
|
||||||
|
POST /_synapse/replication/fed_send_events/:txn_id
|
||||||
|
|
||||||
|
{
|
||||||
|
"events": [{
|
||||||
|
"event": { .. serialized event .. },
|
||||||
|
"internal_metadata": { .. serialized internal_metadata .. },
|
||||||
|
"rejected_reason": .., // The event.rejected_reason field
|
||||||
|
"context": { .. serialized event context .. },
|
||||||
|
}],
|
||||||
|
"backfilled": false
|
||||||
|
"""
|
||||||
|
|
||||||
|
NAME = "fed_send_events"
|
||||||
|
PATH_ARGS = ()
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
|
||||||
|
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.federation_handler = hs.get_handlers().federation_handler
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _serialize_payload(store, event_and_contexts, backfilled):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
store
|
||||||
|
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
|
||||||
|
backfilled (bool): Whether or not the events are the result of
|
||||||
|
backfilling
|
||||||
|
"""
|
||||||
|
event_payloads = []
|
||||||
|
for event, context in event_and_contexts:
|
||||||
|
serialized_context = yield context.serialize(event, store)
|
||||||
|
|
||||||
|
event_payloads.append({
|
||||||
|
"event": event.get_pdu_json(),
|
||||||
|
"internal_metadata": event.internal_metadata.get_dict(),
|
||||||
|
"rejected_reason": event.rejected_reason,
|
||||||
|
"context": serialized_context,
|
||||||
|
})
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"events": event_payloads,
|
||||||
|
"backfilled": backfilled,
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue(payload)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_request(self, request):
|
||||||
|
with Measure(self.clock, "repl_fed_send_events_parse"):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
backfilled = content["backfilled"]
|
||||||
|
|
||||||
|
event_payloads = content["events"]
|
||||||
|
|
||||||
|
event_and_contexts = []
|
||||||
|
for event_payload in event_payloads:
|
||||||
|
event_dict = event_payload["event"]
|
||||||
|
internal_metadata = event_payload["internal_metadata"]
|
||||||
|
rejected_reason = event_payload["rejected_reason"]
|
||||||
|
event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
|
||||||
|
|
||||||
|
context = yield EventContext.deserialize(
|
||||||
|
self.store, event_payload["context"],
|
||||||
|
)
|
||||||
|
|
||||||
|
event_and_contexts.append((event, context))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Got %d events from federation",
|
||||||
|
len(event_and_contexts),
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.federation_handler.persist_events_and_notify(
|
||||||
|
event_and_contexts, backfilled,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
|
||||||
|
"""Handles EDUs newly received from federation, including persisting and
|
||||||
|
notifying.
|
||||||
|
|
||||||
|
Request format:
|
||||||
|
|
||||||
|
POST /_synapse/replication/fed_send_edu/:edu_type/:txn_id
|
||||||
|
|
||||||
|
{
|
||||||
|
"origin": ...,
|
||||||
|
"content: { ... }
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
NAME = "fed_send_edu"
|
||||||
|
PATH_ARGS = ("edu_type",)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
|
||||||
|
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.registry = hs.get_federation_registry()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_payload(edu_type, origin, content):
|
||||||
|
return {
|
||||||
|
"origin": origin,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_request(self, request, edu_type):
|
||||||
|
with Measure(self.clock, "repl_fed_send_edu_parse"):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
origin = content["origin"]
|
||||||
|
edu_content = content["content"]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Got %r edu from $s",
|
||||||
|
edu_type, origin,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = yield self.registry.on_edu(edu_type, origin, edu_content)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicationGetQueryRestServlet(ReplicationEndpoint):
|
||||||
|
"""Handle responding to queries from federation.
|
||||||
|
|
||||||
|
Request format:
|
||||||
|
|
||||||
|
POST /_synapse/replication/fed_query/:query_type
|
||||||
|
|
||||||
|
{
|
||||||
|
"args": { ... }
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
NAME = "fed_query"
|
||||||
|
PATH_ARGS = ("query_type",)
|
||||||
|
|
||||||
|
# This is a query, so let's not bother caching
|
||||||
|
CACHE = False
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ReplicationGetQueryRestServlet, self).__init__(hs)
|
||||||
|
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.registry = hs.get_federation_registry()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_payload(query_type, args):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
query_type (str)
|
||||||
|
args (dict): The arguments received for the given query type
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"args": args,
|
||||||
|
}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_request(self, request, query_type):
|
||||||
|
with Measure(self.clock, "repl_fed_query_parse"):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
args = content["args"]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Got %r query",
|
||||||
|
query_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = yield self.registry.on_query(query_type, args)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
|
||||||
|
"""Called to clean up any data in DB for a given room, ready for the
|
||||||
|
server to join the room.
|
||||||
|
|
||||||
|
Request format:
|
||||||
|
|
||||||
|
POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id
|
||||||
|
|
||||||
|
{}
|
||||||
|
"""
|
||||||
|
|
||||||
|
NAME = "fed_cleanup_room"
|
||||||
|
PATH_ARGS = ("room_id",)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ReplicationCleanRoomRestServlet, self).__init__(hs)
|
||||||
|
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_payload(room_id, args):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
room_id (str)
|
||||||
|
"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_request(self, request, room_id):
|
||||||
|
yield self.store.clean_room_for_join(room_id)
|
||||||
|
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
|
||||||
|
ReplicationFederationSendEduRestServlet(hs).register(http_server)
|
||||||
|
ReplicationGetQueryRestServlet(hs).register(http_server)
|
||||||
|
ReplicationCleanRoomRestServlet(hs).register(http_server)
|
@ -13,19 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.storage import DataStore
|
|
||||||
from synapse.storage.transactions import TransactionStore
|
from synapse.storage.transactions import TransactionStore
|
||||||
|
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
|
|
||||||
|
|
||||||
class TransactionStore(BaseSlavedStore):
|
class SlavedTransactionStore(TransactionStore, BaseSlavedStore):
|
||||||
get_destination_retry_timings = TransactionStore.__dict__[
|
pass
|
||||||
"get_destination_retry_timings"
|
|
||||||
]
|
|
||||||
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
|
|
||||||
set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
|
|
||||||
_set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
|
|
||||||
|
|
||||||
prep_send_transaction = DataStore.prep_send_transaction.__func__
|
|
||||||
delivered_txn = DataStore.delivered_txn.__func__
|
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
to ensure idempotency when performing PUTs using the REST API."""
|
to ensure idempotency when performing PUTs using the REST API."""
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async_helpers import ObservableDeferred
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -391,10 +391,17 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
|
|||||||
if not is_admin:
|
if not is_admin:
|
||||||
raise AuthError(403, "You are not a server admin")
|
raise AuthError(403, "You are not a server admin")
|
||||||
|
|
||||||
yield self._deactivate_account_handler.deactivate_account(
|
result = yield self._deactivate_account_handler.deactivate_account(
|
||||||
target_user_id, erase,
|
target_user_id, erase,
|
||||||
)
|
)
|
||||||
defer.returnValue((200, {}))
|
if result:
|
||||||
|
id_server_unbind_result = "success"
|
||||||
|
else:
|
||||||
|
id_server_unbind_result = "no-support"
|
||||||
|
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"id_server_unbind_result": id_server_unbind_result,
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
class ShutdownRoomRestServlet(ClientV1RestServlet):
|
class ShutdownRoomRestServlet(ClientV1RestServlet):
|
||||||
|
@ -209,10 +209,17 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||||||
yield self.auth_handler.validate_user_via_ui_auth(
|
yield self.auth_handler.validate_user_via_ui_auth(
|
||||||
requester, body, self.hs.get_ip_from_request(request),
|
requester, body, self.hs.get_ip_from_request(request),
|
||||||
)
|
)
|
||||||
yield self._deactivate_account_handler.deactivate_account(
|
result = yield self._deactivate_account_handler.deactivate_account(
|
||||||
requester.user.to_string(), erase,
|
requester.user.to_string(), erase,
|
||||||
)
|
)
|
||||||
defer.returnValue((200, {}))
|
if result:
|
||||||
|
id_server_unbind_result = "success"
|
||||||
|
else:
|
||||||
|
id_server_unbind_result = "no-support"
|
||||||
|
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"id_server_unbind_result": id_server_unbind_result,
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
class EmailThreepidRequestTokenRestServlet(RestServlet):
|
class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||||
@ -364,7 +371,7 @@ class ThreepidDeleteRestServlet(RestServlet):
|
|||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.auth_handler.delete_threepid(
|
ret = yield self.auth_handler.delete_threepid(
|
||||||
user_id, body['medium'], body['address']
|
user_id, body['medium'], body['address']
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -374,7 +381,14 @@ class ThreepidDeleteRestServlet(RestServlet):
|
|||||||
logger.exception("Failed to remove threepid")
|
logger.exception("Failed to remove threepid")
|
||||||
raise SynapseError(500, "Failed to remove threepid")
|
raise SynapseError(500, "Failed to remove threepid")
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
if ret:
|
||||||
|
id_server_unbind_result = "success"
|
||||||
|
else:
|
||||||
|
id_server_unbind_result = "no-support"
|
||||||
|
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"id_server_unbind_result": id_server_unbind_result,
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
class WhoamiRestServlet(RestServlet):
|
class WhoamiRestServlet(RestServlet):
|
||||||
|
@ -36,7 +36,7 @@ from synapse.api.errors import (
|
|||||||
)
|
)
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.logcontext import make_deferred_yieldable
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
from synapse.util.stringutils import is_ascii, random_string
|
from synapse.util.stringutils import is_ascii, random_string
|
||||||
|
@ -42,7 +42,7 @@ from synapse.http.server import (
|
|||||||
)
|
)
|
||||||
from synapse.http.servlet import parse_integer, parse_string
|
from synapse.http.servlet import parse_integer, parse_string
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async_helpers import ObservableDeferred
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||||
from synapse.util.stringutils import is_ascii, random_string
|
from synapse.util.stringutils import is_ascii, random_string
|
||||||
|
@ -36,6 +36,7 @@ from synapse.federation.federation_client import FederationClient
|
|||||||
from synapse.federation.federation_server import (
|
from synapse.federation.federation_server import (
|
||||||
FederationHandlerRegistry,
|
FederationHandlerRegistry,
|
||||||
FederationServer,
|
FederationServer,
|
||||||
|
ReplicationFederationHandlerRegistry,
|
||||||
)
|
)
|
||||||
from synapse.federation.send_queue import FederationRemoteSendQueue
|
from synapse.federation.send_queue import FederationRemoteSendQueue
|
||||||
from synapse.federation.transaction_queue import TransactionQueue
|
from synapse.federation.transaction_queue import TransactionQueue
|
||||||
@ -423,7 +424,10 @@ class HomeServer(object):
|
|||||||
return RoomMemberMasterHandler(self)
|
return RoomMemberMasterHandler(self)
|
||||||
|
|
||||||
def build_federation_registry(self):
|
def build_federation_registry(self):
|
||||||
return FederationHandlerRegistry()
|
if self.config.worker_app:
|
||||||
|
return ReplicationFederationHandlerRegistry(self)
|
||||||
|
else:
|
||||||
|
return FederationHandlerRegistry()
|
||||||
|
|
||||||
def build_server_notices_manager(self):
|
def build_server_notices_manager(self):
|
||||||
if self.config.worker_app:
|
if self.config.worker_app:
|
||||||
|
@ -28,7 +28,7 @@ from synapse import event_auth
|
|||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
|
<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
|
||||||
<link rel="stylesheet" href="style.css">
|
<link rel="stylesheet" href="style.css">
|
||||||
<script src="js/jquery-2.1.3.min.js"></script>
|
<script src="js/jquery-2.1.3.min.js"></script>
|
||||||
<script src="js/recaptcha_ajax.js"></script>
|
<script src="https://www.google.com/recaptcha/api/js/recaptcha_ajax.js"></script>
|
||||||
<script src="register_config.js"></script>
|
<script src="register_config.js"></script>
|
||||||
<script src="js/register.js"></script>
|
<script src="js/register.js"></script>
|
||||||
</head>
|
</head>
|
||||||
|
File diff suppressed because one or more lines are too long
@ -96,6 +96,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||||||
self._batch_row_update[key] = (user_agent, device_id, now)
|
self._batch_row_update[key] = (user_agent, device_id, now)
|
||||||
|
|
||||||
def _update_client_ips_batch(self):
|
def _update_client_ips_batch(self):
|
||||||
|
|
||||||
|
# If the DB pool has already terminated, don't try updating
|
||||||
|
if not self.hs.get_db_pool().running:
|
||||||
|
return
|
||||||
|
|
||||||
def update():
|
def update():
|
||||||
to_update = self._batch_row_update
|
to_update = self._batch_row_update
|
||||||
self._batch_row_update = {}
|
self._batch_row_update = {}
|
||||||
|
@ -38,7 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore
|
|||||||
from synapse.storage.event_federation import EventFederationStore
|
from synapse.storage.event_federation import EventFederationStore
|
||||||
from synapse.storage.events_worker import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.types import RoomStreamToken, get_domain_from_id
|
from synapse.types import RoomStreamToken, get_domain_from_id
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async_helpers import ObservableDeferred
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
from synapse.util.frozenutils import frozendict_json_encoder
|
from synapse.util.frozenutils import frozendict_json_encoder
|
||||||
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
|
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
|
||||||
@ -1435,88 +1435,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
|||||||
(event.event_id, event.redacts)
|
(event.event_id, event.redacts)
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def have_events_in_timeline(self, event_ids):
|
|
||||||
"""Given a list of event ids, check if we have already processed and
|
|
||||||
stored them as non outliers.
|
|
||||||
"""
|
|
||||||
rows = yield self._simple_select_many_batch(
|
|
||||||
table="events",
|
|
||||||
retcols=("event_id",),
|
|
||||||
column="event_id",
|
|
||||||
iterable=list(event_ids),
|
|
||||||
keyvalues={"outlier": False},
|
|
||||||
desc="have_events_in_timeline",
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(set(r["event_id"] for r in rows))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def have_seen_events(self, event_ids):
|
|
||||||
"""Given a list of event ids, check if we have already processed them.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_ids (iterable[str]):
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[set[str]]: The events we have already seen.
|
|
||||||
"""
|
|
||||||
results = set()
|
|
||||||
|
|
||||||
def have_seen_events_txn(txn, chunk):
|
|
||||||
sql = (
|
|
||||||
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
|
|
||||||
% (",".join("?" * len(chunk)), )
|
|
||||||
)
|
|
||||||
txn.execute(sql, chunk)
|
|
||||||
for (event_id, ) in txn:
|
|
||||||
results.add(event_id)
|
|
||||||
|
|
||||||
# break the input up into chunks of 100
|
|
||||||
input_iterator = iter(event_ids)
|
|
||||||
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
|
|
||||||
[]):
|
|
||||||
yield self.runInteraction(
|
|
||||||
"have_seen_events",
|
|
||||||
have_seen_events_txn,
|
|
||||||
chunk,
|
|
||||||
)
|
|
||||||
defer.returnValue(results)
|
|
||||||
|
|
||||||
def get_seen_events_with_rejections(self, event_ids):
|
|
||||||
"""Given a list of event ids, check if we rejected them.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_ids (list[str])
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[dict[str, str|None):
|
|
||||||
Has an entry for each event id we already have seen. Maps to
|
|
||||||
the rejected reason string if we rejected the event, else maps
|
|
||||||
to None.
|
|
||||||
"""
|
|
||||||
if not event_ids:
|
|
||||||
return defer.succeed({})
|
|
||||||
|
|
||||||
def f(txn):
|
|
||||||
sql = (
|
|
||||||
"SELECT e.event_id, reason FROM events as e "
|
|
||||||
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
|
|
||||||
"WHERE e.event_id = ?"
|
|
||||||
)
|
|
||||||
|
|
||||||
res = {}
|
|
||||||
for event_id in event_ids:
|
|
||||||
txn.execute(sql, (event_id,))
|
|
||||||
row = txn.fetchone()
|
|
||||||
if row:
|
|
||||||
_, rejected = row
|
|
||||||
res[event_id] = rejected
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
return self.runInteraction("get_rejection_reasons", f)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def count_daily_messages(self):
|
def count_daily_messages(self):
|
||||||
"""
|
"""
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
@ -442,3 +443,85 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
|
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
|
||||||
|
|
||||||
defer.returnValue(cache_entry)
|
defer.returnValue(cache_entry)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def have_events_in_timeline(self, event_ids):
|
||||||
|
"""Given a list of event ids, check if we have already processed and
|
||||||
|
stored them as non outliers.
|
||||||
|
"""
|
||||||
|
rows = yield self._simple_select_many_batch(
|
||||||
|
table="events",
|
||||||
|
retcols=("event_id",),
|
||||||
|
column="event_id",
|
||||||
|
iterable=list(event_ids),
|
||||||
|
keyvalues={"outlier": False},
|
||||||
|
desc="have_events_in_timeline",
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(set(r["event_id"] for r in rows))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def have_seen_events(self, event_ids):
|
||||||
|
"""Given a list of event ids, check if we have already processed them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_ids (iterable[str]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[set[str]]: The events we have already seen.
|
||||||
|
"""
|
||||||
|
results = set()
|
||||||
|
|
||||||
|
def have_seen_events_txn(txn, chunk):
|
||||||
|
sql = (
|
||||||
|
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
|
||||||
|
% (",".join("?" * len(chunk)), )
|
||||||
|
)
|
||||||
|
txn.execute(sql, chunk)
|
||||||
|
for (event_id, ) in txn:
|
||||||
|
results.add(event_id)
|
||||||
|
|
||||||
|
# break the input up into chunks of 100
|
||||||
|
input_iterator = iter(event_ids)
|
||||||
|
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
|
||||||
|
[]):
|
||||||
|
yield self.runInteraction(
|
||||||
|
"have_seen_events",
|
||||||
|
have_seen_events_txn,
|
||||||
|
chunk,
|
||||||
|
)
|
||||||
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
def get_seen_events_with_rejections(self, event_ids):
|
||||||
|
"""Given a list of event ids, check if we rejected them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_ids (list[str])
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[dict[str, str|None):
|
||||||
|
Has an entry for each event id we already have seen. Maps to
|
||||||
|
the rejected reason string if we rejected the event, else maps
|
||||||
|
to None.
|
||||||
|
"""
|
||||||
|
if not event_ids:
|
||||||
|
return defer.succeed({})
|
||||||
|
|
||||||
|
def f(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT e.event_id, reason FROM events as e "
|
||||||
|
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
|
||||||
|
"WHERE e.event_id = ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
res = {}
|
||||||
|
for event_id in event_ids:
|
||||||
|
txn.execute(sql, (event_id,))
|
||||||
|
row = txn.fetchone()
|
||||||
|
if row:
|
||||||
|
_, rejected = row
|
||||||
|
res[event_id] = rejected
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
return self.runInteraction("get_rejection_reasons", f)
|
||||||
|
@ -46,7 +46,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
|||||||
tp["medium"], tp["address"]
|
tp["medium"], tp["address"]
|
||||||
)
|
)
|
||||||
if user_id:
|
if user_id:
|
||||||
self.upsert_monthly_active_user(user_id)
|
yield self.upsert_monthly_active_user(user_id)
|
||||||
reserved_user_list.append(user_id)
|
reserved_user_list.append(user_id)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -64,23 +64,27 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
|||||||
Deferred[]
|
Deferred[]
|
||||||
"""
|
"""
|
||||||
def _reap_users(txn):
|
def _reap_users(txn):
|
||||||
|
# Purge stale users
|
||||||
|
|
||||||
thirty_days_ago = (
|
thirty_days_ago = (
|
||||||
int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
|
int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
|
||||||
)
|
)
|
||||||
# Purge stale users
|
|
||||||
|
|
||||||
# questionmarks is a hack to overcome sqlite not supporting
|
|
||||||
# tuples in 'WHERE IN %s'
|
|
||||||
questionmarks = '?' * len(self.reserved_users)
|
|
||||||
query_args = [thirty_days_ago]
|
query_args = [thirty_days_ago]
|
||||||
query_args.extend(self.reserved_users)
|
base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
|
||||||
|
|
||||||
sql = """
|
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
|
||||||
DELETE FROM monthly_active_users
|
# when len(reserved_users) == 0. Works fine on sqlite.
|
||||||
WHERE timestamp < ?
|
if len(self.reserved_users) > 0:
|
||||||
AND user_id NOT IN ({})
|
# questionmarks is a hack to overcome sqlite not supporting
|
||||||
""".format(','.join(questionmarks))
|
# tuples in 'WHERE IN %s'
|
||||||
|
questionmarks = '?' * len(self.reserved_users)
|
||||||
|
|
||||||
|
query_args.extend(self.reserved_users)
|
||||||
|
sql = base_sql + """ AND user_id NOT IN ({})""".format(
|
||||||
|
','.join(questionmarks)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sql = base_sql
|
||||||
|
|
||||||
txn.execute(sql, query_args)
|
txn.execute(sql, query_args)
|
||||||
|
|
||||||
@ -93,16 +97,24 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
|||||||
# negative LIMIT values. So there is no way to write it that both can
|
# negative LIMIT values. So there is no way to write it that both can
|
||||||
# support
|
# support
|
||||||
query_args = [self.hs.config.max_mau_value]
|
query_args = [self.hs.config.max_mau_value]
|
||||||
query_args.extend(self.reserved_users)
|
|
||||||
sql = """
|
base_sql = """
|
||||||
DELETE FROM monthly_active_users
|
DELETE FROM monthly_active_users
|
||||||
WHERE user_id NOT IN (
|
WHERE user_id NOT IN (
|
||||||
SELECT user_id FROM monthly_active_users
|
SELECT user_id FROM monthly_active_users
|
||||||
ORDER BY timestamp DESC
|
ORDER BY timestamp DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
)
|
)
|
||||||
AND user_id NOT IN ({})
|
"""
|
||||||
""".format(','.join(questionmarks))
|
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
|
||||||
|
# when len(reserved_users) == 0. Works fine on sqlite.
|
||||||
|
if len(self.reserved_users) > 0:
|
||||||
|
query_args.extend(self.reserved_users)
|
||||||
|
sql = base_sql + """ AND user_id NOT IN ({})""".format(
|
||||||
|
','.join(questionmarks)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sql = base_sql
|
||||||
txn.execute(sql, query_args)
|
txn.execute(sql, query_args)
|
||||||
|
|
||||||
yield self.runInteraction("reap_monthly_active_users", _reap_users)
|
yield self.runInteraction("reap_monthly_active_users", _reap_users)
|
||||||
|
@ -41,6 +41,22 @@ RatelimitOverride = collections.namedtuple(
|
|||||||
|
|
||||||
|
|
||||||
class RoomWorkerStore(SQLBaseStore):
|
class RoomWorkerStore(SQLBaseStore):
|
||||||
|
def get_room(self, room_id):
|
||||||
|
"""Retrieve a room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str): The ID of the room to retrieve.
|
||||||
|
Returns:
|
||||||
|
A namedtuple containing the room information, or an empty list.
|
||||||
|
"""
|
||||||
|
return self._simple_select_one(
|
||||||
|
table="rooms",
|
||||||
|
keyvalues={"room_id": room_id},
|
||||||
|
retcols=("room_id", "is_public", "creator"),
|
||||||
|
desc="get_room",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
def get_public_room_ids(self):
|
def get_public_room_ids(self):
|
||||||
return self._simple_select_onecol(
|
return self._simple_select_onecol(
|
||||||
table="rooms",
|
table="rooms",
|
||||||
@ -215,22 +231,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
|||||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||||
raise StoreError(500, "Problem creating room.")
|
raise StoreError(500, "Problem creating room.")
|
||||||
|
|
||||||
def get_room(self, room_id):
|
|
||||||
"""Retrieve a room.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
room_id (str): The ID of the room to retrieve.
|
|
||||||
Returns:
|
|
||||||
A namedtuple containing the room information, or an empty list.
|
|
||||||
"""
|
|
||||||
return self._simple_select_one(
|
|
||||||
table="rooms",
|
|
||||||
keyvalues={"room_id": room_id},
|
|
||||||
retcols=("room_id", "is_public", "creator"),
|
|
||||||
desc="get_room",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_room_is_public(self, room_id, is_public):
|
def set_room_is_public(self, room_id, is_public):
|
||||||
def set_room_is_public_txn(txn, next_id):
|
def set_room_is_public_txn(txn, next_id):
|
||||||
|
@ -26,7 +26,7 @@ from twisted.internet import defer
|
|||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.storage.events_worker import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
from synapse.util.stringutils import to_ascii
|
from synapse.util.stringutils import to_ascii
|
||||||
|
@ -188,62 +188,30 @@ class Linearizer(object):
|
|||||||
# things blocked from executing.
|
# things blocked from executing.
|
||||||
self.key_to_defer = {}
|
self.key_to_defer = {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def queue(self, key):
|
def queue(self, key):
|
||||||
|
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
|
||||||
|
# (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
|
||||||
|
# propagated inside inlineCallbacks until Twisted 18.7)
|
||||||
entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
|
entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
|
||||||
|
|
||||||
# If the number of things executing is greater than the maximum
|
# If the number of things executing is greater than the maximum
|
||||||
# then add a deferred to the list of blocked items
|
# then add a deferred to the list of blocked items
|
||||||
# When on of the things currently executing finishes it will callback
|
# When one of the things currently executing finishes it will callback
|
||||||
# this item so that it can continue executing.
|
# this item so that it can continue executing.
|
||||||
if entry[0] >= self.max_count:
|
if entry[0] >= self.max_count:
|
||||||
new_defer = defer.Deferred()
|
res = self._await_lock(key)
|
||||||
entry[1][new_defer] = 1
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
yield make_deferred_yieldable(new_defer)
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, CancelledError):
|
|
||||||
logger.info(
|
|
||||||
"Cancelling wait for linearizer lock %r for key %r",
|
|
||||||
self.name, key,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warn(
|
|
||||||
"Unexpected exception waiting for linearizer lock %r for key %r",
|
|
||||||
self.name, key,
|
|
||||||
)
|
|
||||||
|
|
||||||
# we just have to take ourselves back out of the queue.
|
|
||||||
del entry[1][new_defer]
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
|
|
||||||
entry[0] += 1
|
|
||||||
|
|
||||||
# if the code holding the lock completes synchronously, then it
|
|
||||||
# will recursively run the next claimant on the list. That can
|
|
||||||
# relatively rapidly lead to stack exhaustion. This is essentially
|
|
||||||
# the same problem as http://twistedmatrix.com/trac/ticket/9304.
|
|
||||||
#
|
|
||||||
# In order to break the cycle, we add a cheeky sleep(0) here to
|
|
||||||
# ensure that we fall back to the reactor between each iteration.
|
|
||||||
#
|
|
||||||
# (This needs to happen while we hold the lock, and the context manager's exit
|
|
||||||
# code must be synchronous, so this is the only sensible place.)
|
|
||||||
yield self._clock.sleep(0)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Acquired uncontended linearizer lock %r for key %r", self.name, key,
|
"Acquired uncontended linearizer lock %r for key %r", self.name, key,
|
||||||
)
|
)
|
||||||
entry[0] += 1
|
entry[0] += 1
|
||||||
|
res = defer.succeed(None)
|
||||||
|
|
||||||
|
# once we successfully get the lock, we need to return a context manager which
|
||||||
|
# will release the lock.
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _ctx_manager():
|
def _ctx_manager(_):
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
@ -264,7 +232,64 @@ class Linearizer(object):
|
|||||||
# map.
|
# map.
|
||||||
del self.key_to_defer[key]
|
del self.key_to_defer[key]
|
||||||
|
|
||||||
defer.returnValue(_ctx_manager())
|
res.addCallback(_ctx_manager)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _await_lock(self, key):
|
||||||
|
"""Helper for queue: adds a deferred to the queue
|
||||||
|
|
||||||
|
Assumes that we've already checked that we've reached the limit of the number
|
||||||
|
of lock-holders we allow. Creates a new deferred which is added to the list, and
|
||||||
|
adds some management around cancellations.
|
||||||
|
|
||||||
|
Returns the deferred, which will callback once we have secured the lock.
|
||||||
|
|
||||||
|
"""
|
||||||
|
entry = self.key_to_defer[key]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_defer = make_deferred_yieldable(defer.Deferred())
|
||||||
|
entry[1][new_defer] = 1
|
||||||
|
|
||||||
|
def cb(_r):
|
||||||
|
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
|
||||||
|
entry[0] += 1
|
||||||
|
|
||||||
|
# if the code holding the lock completes synchronously, then it
|
||||||
|
# will recursively run the next claimant on the list. That can
|
||||||
|
# relatively rapidly lead to stack exhaustion. This is essentially
|
||||||
|
# the same problem as http://twistedmatrix.com/trac/ticket/9304.
|
||||||
|
#
|
||||||
|
# In order to break the cycle, we add a cheeky sleep(0) here to
|
||||||
|
# ensure that we fall back to the reactor between each iteration.
|
||||||
|
#
|
||||||
|
# (This needs to happen while we hold the lock, and the context manager's exit
|
||||||
|
# code must be synchronous, so this is the only sensible place.)
|
||||||
|
return self._clock.sleep(0)
|
||||||
|
|
||||||
|
def eb(e):
|
||||||
|
logger.info("defer %r got err %r", new_defer, e)
|
||||||
|
if isinstance(e, CancelledError):
|
||||||
|
logger.info(
|
||||||
|
"Cancelling wait for linearizer lock %r for key %r",
|
||||||
|
self.name, key,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warn(
|
||||||
|
"Unexpected exception waiting for linearizer lock %r for key %r",
|
||||||
|
self.name, key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# we just have to take ourselves back out of the queue.
|
||||||
|
del entry[1][new_defer]
|
||||||
|
return e
|
||||||
|
|
||||||
|
new_defer.addCallbacks(cb, eb)
|
||||||
|
return new_defer
|
||||||
|
|
||||||
|
|
||||||
class ReadWriteLock(object):
|
class ReadWriteLock(object):
|
@ -25,7 +25,7 @@ from six import itervalues, string_types
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util import logcontext, unwrapFirstError
|
from synapse.util import logcontext, unwrapFirstError
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async_helpers import ObservableDeferred
|
||||||
from synapse.util.caches import get_cache_factor_for
|
from synapse.util.caches import get_cache_factor_for
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||||
|
@ -16,7 +16,7 @@ import logging
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async_helpers import ObservableDeferred
|
||||||
from synapse.util.caches import register_cache
|
from synapse.util.caches import register_cache
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async_helpers import ObservableDeferred
|
||||||
|
|
||||||
|
|
||||||
class SnapshotCache(object):
|
class SnapshotCache(object):
|
||||||
|
@ -526,7 +526,7 @@ _to_ignore = [
|
|||||||
"synapse.util.logcontext",
|
"synapse.util.logcontext",
|
||||||
"synapse.http.server",
|
"synapse.http.server",
|
||||||
"synapse.storage._base",
|
"synapse.storage._base",
|
||||||
"synapse.util.async",
|
"synapse.util.async_helpers",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,4 +15,7 @@
|
|||||||
|
|
||||||
from twisted.trial import util
|
from twisted.trial import util
|
||||||
|
|
||||||
|
from tests import utils
|
||||||
|
|
||||||
util.DEFAULT_TIMEOUT_DURATION = 10
|
util.DEFAULT_TIMEOUT_DURATION = 10
|
||||||
|
utils.setupdb()
|
||||||
|
@ -34,13 +34,12 @@ class TestHandlers(object):
|
|||||||
|
|
||||||
|
|
||||||
class AuthTestCase(unittest.TestCase):
|
class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.state_handler = Mock()
|
self.state_handler = Mock()
|
||||||
self.store = Mock()
|
self.store = Mock()
|
||||||
|
|
||||||
self.hs = yield setup_test_homeserver(handlers=None)
|
self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
|
||||||
self.hs.get_datastore = Mock(return_value=self.store)
|
self.hs.get_datastore = Mock(return_value=self.store)
|
||||||
self.hs.handlers = TestHandlers(self.hs)
|
self.hs.handlers = TestHandlers(self.hs)
|
||||||
self.auth = Auth(self.hs)
|
self.auth = Auth(self.hs)
|
||||||
@ -53,11 +52,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_by_req_user_valid_token(self):
|
def test_get_user_by_req_user_valid_token(self):
|
||||||
user_info = {
|
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
|
||||||
"name": self.test_user,
|
|
||||||
"token_id": "ditto",
|
|
||||||
"device_id": "device",
|
|
||||||
}
|
|
||||||
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
@ -76,10 +71,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
|
|
||||||
def test_get_user_by_req_user_missing_token(self):
|
def test_get_user_by_req_user_missing_token(self):
|
||||||
user_info = {
|
user_info = {"name": self.test_user, "token_id": "ditto"}
|
||||||
"name": self.test_user,
|
|
||||||
"token_id": "ditto",
|
|
||||||
}
|
|
||||||
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
@ -90,8 +82,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_by_req_appservice_valid_token(self):
|
def test_get_user_by_req_appservice_valid_token(self):
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user,
|
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
||||||
ip_range_whitelist=None,
|
|
||||||
)
|
)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
@ -106,8 +97,11 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_by_req_appservice_valid_token_good_ip(self):
|
def test_get_user_by_req_appservice_valid_token_good_ip(self):
|
||||||
from netaddr import IPSet
|
from netaddr import IPSet
|
||||||
|
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user,
|
token="foobar",
|
||||||
|
url="a_url",
|
||||||
|
sender=self.test_user,
|
||||||
ip_range_whitelist=IPSet(["192.168/16"]),
|
ip_range_whitelist=IPSet(["192.168/16"]),
|
||||||
)
|
)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
@ -122,8 +116,11 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
|
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
|
||||||
from netaddr import IPSet
|
from netaddr import IPSet
|
||||||
|
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user,
|
token="foobar",
|
||||||
|
url="a_url",
|
||||||
|
sender=self.test_user,
|
||||||
ip_range_whitelist=IPSet(["192.168/16"]),
|
ip_range_whitelist=IPSet(["192.168/16"]),
|
||||||
)
|
)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
@ -160,8 +157,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
|
||||||
masquerading_user_id = b"@doppelganger:matrix.org"
|
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user,
|
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
||||||
ip_range_whitelist=None,
|
|
||||||
)
|
)
|
||||||
app_service.is_interested_in_user = Mock(return_value=True)
|
app_service.is_interested_in_user = Mock(return_value=True)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
@ -174,15 +170,13 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
requester.user.to_string(),
|
requester.user.to_string(), masquerading_user_id.decode('utf8')
|
||||||
masquerading_user_id.decode('utf8')
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
||||||
masquerading_user_id = b"@doppelganger:matrix.org"
|
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user,
|
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
||||||
ip_range_whitelist=None,
|
|
||||||
)
|
)
|
||||||
app_service.is_interested_in_user = Mock(return_value=False)
|
app_service.is_interested_in_user = Mock(return_value=False)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
@ -201,17 +195,15 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
# TODO(danielwh): Remove this mock when we remove the
|
# TODO(danielwh): Remove this mock when we remove the
|
||||||
# get_user_by_access_token fallback.
|
# get_user_by_access_token fallback.
|
||||||
self.store.get_user_by_access_token = Mock(
|
self.store.get_user_by_access_token = Mock(
|
||||||
return_value={
|
return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
|
||||||
"name": "@baldrick:matrix.org",
|
|
||||||
"device_id": "device",
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id = "@baldrick:matrix.org"
|
user_id = "@baldrick:matrix.org"
|
||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self.hs.config.server_name,
|
location=self.hs.config.server_name,
|
||||||
identifier="key",
|
identifier="key",
|
||||||
key=self.hs.config.macaroon_secret_key)
|
key=self.hs.config.macaroon_secret_key,
|
||||||
|
)
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
@ -225,15 +217,14 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_guest_user_from_macaroon(self):
|
def test_get_guest_user_from_macaroon(self):
|
||||||
self.store.get_user_by_id = Mock(return_value={
|
self.store.get_user_by_id = Mock(return_value={"is_guest": True})
|
||||||
"is_guest": True,
|
|
||||||
})
|
|
||||||
|
|
||||||
user_id = "@baldrick:matrix.org"
|
user_id = "@baldrick:matrix.org"
|
||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self.hs.config.server_name,
|
location=self.hs.config.server_name,
|
||||||
identifier="key",
|
identifier="key",
|
||||||
key=self.hs.config.macaroon_secret_key)
|
key=self.hs.config.macaroon_secret_key,
|
||||||
|
)
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
@ -257,7 +248,8 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self.hs.config.server_name,
|
location=self.hs.config.server_name,
|
||||||
identifier="key",
|
identifier="key",
|
||||||
key=self.hs.config.macaroon_secret_key)
|
key=self.hs.config.macaroon_secret_key,
|
||||||
|
)
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||||
@ -277,7 +269,8 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self.hs.config.server_name,
|
location=self.hs.config.server_name,
|
||||||
identifier="key",
|
identifier="key",
|
||||||
key=self.hs.config.macaroon_secret_key)
|
key=self.hs.config.macaroon_secret_key,
|
||||||
|
)
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
|
|
||||||
@ -298,7 +291,8 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self.hs.config.server_name,
|
location=self.hs.config.server_name,
|
||||||
identifier="key",
|
identifier="key",
|
||||||
key=self.hs.config.macaroon_secret_key + "wrong")
|
key=self.hs.config.macaroon_secret_key + "wrong",
|
||||||
|
)
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||||
@ -320,7 +314,8 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self.hs.config.server_name,
|
location=self.hs.config.server_name,
|
||||||
identifier="key",
|
identifier="key",
|
||||||
key=self.hs.config.macaroon_secret_key)
|
key=self.hs.config.macaroon_secret_key,
|
||||||
|
)
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||||
@ -347,7 +342,8 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self.hs.config.server_name,
|
location=self.hs.config.server_name,
|
||||||
identifier="key",
|
identifier="key",
|
||||||
key=self.hs.config.macaroon_secret_key)
|
key=self.hs.config.macaroon_secret_key,
|
||||||
|
)
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||||
@ -380,7 +376,8 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self.hs.config.server_name,
|
location=self.hs.config.server_name,
|
||||||
identifier="key",
|
identifier="key",
|
||||||
key=self.hs.config.macaroon_secret_key)
|
key=self.hs.config.macaroon_secret_key,
|
||||||
|
)
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
@ -401,9 +398,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
token = yield self.hs.handlers.auth_handler.issue_access_token(
|
token = yield self.hs.handlers.auth_handler.issue_access_token(
|
||||||
USER_ID, "DEVICE"
|
USER_ID, "DEVICE"
|
||||||
)
|
)
|
||||||
self.store.add_access_token_to_user.assert_called_with(
|
self.store.add_access_token_to_user.assert_called_with(USER_ID, token, "DEVICE")
|
||||||
USER_ID, token, "DEVICE"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_user(tok):
|
def get_user(tok):
|
||||||
if token != tok:
|
if token != tok:
|
||||||
@ -414,10 +409,9 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
"token_id": 1234,
|
"token_id": 1234,
|
||||||
"device_id": "DEVICE",
|
"device_id": "DEVICE",
|
||||||
}
|
}
|
||||||
|
|
||||||
self.store.get_user_by_access_token = get_user
|
self.store.get_user_by_access_token = get_user
|
||||||
self.store.get_user_by_id = Mock(return_value={
|
self.store.get_user_by_id = Mock(return_value={"is_guest": False})
|
||||||
"is_guest": False,
|
|
||||||
})
|
|
||||||
|
|
||||||
# check the token works
|
# check the token works
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
@ -461,8 +455,11 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
return_value=defer.succeed(lots_of_users)
|
return_value=defer.succeed(lots_of_users)
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(AuthError) as e:
|
||||||
yield self.auth.check_auth_blocking()
|
yield self.auth.check_auth_blocking()
|
||||||
|
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
|
||||||
|
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||||
|
self.assertEquals(e.exception.code, 403)
|
||||||
|
|
||||||
# Ensure does not throw an error
|
# Ensure does not throw an error
|
||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
@ -476,5 +473,6 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
||||||
with self.assertRaises(AuthError) as e:
|
with self.assertRaises(AuthError) as e:
|
||||||
yield self.auth.check_auth_blocking()
|
yield self.auth.check_auth_blocking()
|
||||||
self.assertEquals(e.exception.errcode, Codes.HS_DISABLED)
|
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
|
||||||
|
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||||
self.assertEquals(e.exception.code, 403)
|
self.assertEquals(e.exception.code, 403)
|
||||||
|
@ -38,7 +38,6 @@ def MockEvent(**kwargs):
|
|||||||
|
|
||||||
|
|
||||||
class FilteringTestCase(unittest.TestCase):
|
class FilteringTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.mock_federation_resource = MockHttpResource()
|
self.mock_federation_resource = MockHttpResource()
|
||||||
@ -47,6 +46,7 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
self.mock_http_client.put_json = DeferredMockCallable()
|
self.mock_http_client.put_json = DeferredMockCallable()
|
||||||
|
|
||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(
|
||||||
|
self.addCleanup,
|
||||||
handlers=None,
|
handlers=None,
|
||||||
http_client=self.mock_http_client,
|
http_client=self.mock_http_client,
|
||||||
keyring=Mock(),
|
keyring=Mock(),
|
||||||
@ -64,7 +64,7 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
|
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
|
||||||
{"event_format": "other"},
|
{"event_format": "other"},
|
||||||
{"room": {"not_rooms": ["#foo:pik-test"]}},
|
{"room": {"not_rooms": ["#foo:pik-test"]}},
|
||||||
{"presence": {"senders": ["@bar;pik.test.com"]}}
|
{"presence": {"senders": ["@bar;pik.test.com"]}},
|
||||||
]
|
]
|
||||||
for filter in invalid_filters:
|
for filter in invalid_filters:
|
||||||
with self.assertRaises(SynapseError) as check_filter_error:
|
with self.assertRaises(SynapseError) as check_filter_error:
|
||||||
@ -81,34 +81,34 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
"include_leave": False,
|
"include_leave": False,
|
||||||
"rooms": ["!dee:pik-test"],
|
"rooms": ["!dee:pik-test"],
|
||||||
"not_rooms": ["!gee:pik-test"],
|
"not_rooms": ["!gee:pik-test"],
|
||||||
"account_data": {"limit": 0, "types": ["*"]}
|
"account_data": {"limit": 0, "types": ["*"]},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"room": {
|
"room": {
|
||||||
"state": {
|
"state": {
|
||||||
"types": ["m.room.*"],
|
"types": ["m.room.*"],
|
||||||
"not_rooms": ["!726s6s6q:example.com"]
|
"not_rooms": ["!726s6s6q:example.com"],
|
||||||
},
|
},
|
||||||
"timeline": {
|
"timeline": {
|
||||||
"limit": 10,
|
"limit": 10,
|
||||||
"types": ["m.room.message"],
|
"types": ["m.room.message"],
|
||||||
"not_rooms": ["!726s6s6q:example.com"],
|
"not_rooms": ["!726s6s6q:example.com"],
|
||||||
"not_senders": ["@spam:example.com"]
|
"not_senders": ["@spam:example.com"],
|
||||||
},
|
},
|
||||||
"ephemeral": {
|
"ephemeral": {
|
||||||
"types": ["m.receipt", "m.typing"],
|
"types": ["m.receipt", "m.typing"],
|
||||||
"not_rooms": ["!726s6s6q:example.com"],
|
"not_rooms": ["!726s6s6q:example.com"],
|
||||||
"not_senders": ["@spam:example.com"]
|
"not_senders": ["@spam:example.com"],
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"presence": {
|
"presence": {
|
||||||
"types": ["m.presence"],
|
"types": ["m.presence"],
|
||||||
"not_senders": ["@alice:example.com"]
|
"not_senders": ["@alice:example.com"],
|
||||||
},
|
},
|
||||||
"event_format": "client",
|
"event_format": "client",
|
||||||
"event_fields": ["type", "content", "sender"]
|
"event_fields": ["type", "content", "sender"],
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
for filter in valid_filters:
|
for filter in valid_filters:
|
||||||
try:
|
try:
|
||||||
@ -121,229 +121,131 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def test_definition_types_works_with_literals(self):
|
def test_definition_types_works_with_literals(self):
|
||||||
definition = {
|
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||||
"types": ["m.room.message", "org.matrix.foo.bar"]
|
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||||
}
|
|
||||||
event = MockEvent(
|
|
||||||
sender="@foo:bar",
|
|
||||||
type="m.room.message",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(Filter(definition).check(event))
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_definition_types_works_with_wildcards(self):
|
def test_definition_types_works_with_wildcards(self):
|
||||||
definition = {
|
definition = {"types": ["m.*", "org.matrix.foo.bar"]}
|
||||||
"types": ["m.*", "org.matrix.foo.bar"]
|
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||||
}
|
self.assertTrue(Filter(definition).check(event))
|
||||||
event = MockEvent(
|
|
||||||
sender="@foo:bar",
|
|
||||||
type="m.room.message",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_definition_types_works_with_unknowns(self):
|
def test_definition_types_works_with_unknowns(self):
|
||||||
definition = {
|
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||||
"types": ["m.room.message", "org.matrix.foo.bar"]
|
|
||||||
}
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
type="now.for.something.completely.different",
|
type="now.for.something.completely.different",
|
||||||
room_id="!foo:bar"
|
room_id="!foo:bar",
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_not_types_works_with_literals(self):
|
def test_definition_not_types_works_with_literals(self):
|
||||||
definition = {
|
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||||
"not_types": ["m.room.message", "org.matrix.foo.bar"]
|
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||||
}
|
self.assertFalse(Filter(definition).check(event))
|
||||||
event = MockEvent(
|
|
||||||
sender="@foo:bar",
|
|
||||||
type="m.room.message",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_definition_not_types_works_with_wildcards(self):
|
def test_definition_not_types_works_with_wildcards(self):
|
||||||
definition = {
|
definition = {"not_types": ["m.room.message", "org.matrix.*"]}
|
||||||
"not_types": ["m.room.message", "org.matrix.*"]
|
|
||||||
}
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
|
||||||
type="org.matrix.custom.event",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_not_types_works_with_unknowns(self):
|
def test_definition_not_types_works_with_unknowns(self):
|
||||||
definition = {
|
definition = {"not_types": ["m.*", "org.*"]}
|
||||||
"not_types": ["m.*", "org.*"]
|
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
|
||||||
}
|
self.assertTrue(Filter(definition).check(event))
|
||||||
event = MockEvent(
|
|
||||||
sender="@foo:bar",
|
|
||||||
type="com.nom.nom.nom",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_definition_not_types_takes_priority_over_types(self):
|
def test_definition_not_types_takes_priority_over_types(self):
|
||||||
definition = {
|
definition = {
|
||||||
"not_types": ["m.*", "org.*"],
|
"not_types": ["m.*", "org.*"],
|
||||||
"types": ["m.room.message", "m.room.topic"]
|
"types": ["m.room.message", "m.room.topic"],
|
||||||
}
|
}
|
||||||
event = MockEvent(
|
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
||||||
sender="@foo:bar",
|
self.assertFalse(Filter(definition).check(event))
|
||||||
type="m.room.topic",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_definition_senders_works_with_literals(self):
|
def test_definition_senders_works_with_literals(self):
|
||||||
definition = {
|
definition = {"senders": ["@flibble:wibble"]}
|
||||||
"senders": ["@flibble:wibble"]
|
|
||||||
}
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@flibble:wibble",
|
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
type="com.nom.nom.nom",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertTrue(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_senders_works_with_unknowns(self):
|
def test_definition_senders_works_with_unknowns(self):
|
||||||
definition = {
|
definition = {"senders": ["@flibble:wibble"]}
|
||||||
"senders": ["@flibble:wibble"]
|
|
||||||
}
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@challenger:appears",
|
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
type="com.nom.nom.nom",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_not_senders_works_with_literals(self):
|
def test_definition_not_senders_works_with_literals(self):
|
||||||
definition = {
|
definition = {"not_senders": ["@flibble:wibble"]}
|
||||||
"not_senders": ["@flibble:wibble"]
|
|
||||||
}
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@flibble:wibble",
|
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
type="com.nom.nom.nom",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_not_senders_works_with_unknowns(self):
|
def test_definition_not_senders_works_with_unknowns(self):
|
||||||
definition = {
|
definition = {"not_senders": ["@flibble:wibble"]}
|
||||||
"not_senders": ["@flibble:wibble"]
|
|
||||||
}
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@challenger:appears",
|
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
type="com.nom.nom.nom",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertTrue(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_not_senders_takes_priority_over_senders(self):
|
def test_definition_not_senders_takes_priority_over_senders(self):
|
||||||
definition = {
|
definition = {
|
||||||
"not_senders": ["@misspiggy:muppets"],
|
"not_senders": ["@misspiggy:muppets"],
|
||||||
"senders": ["@kermit:muppets", "@misspiggy:muppets"]
|
"senders": ["@kermit:muppets", "@misspiggy:muppets"],
|
||||||
}
|
}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@misspiggy:muppets",
|
sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar"
|
||||||
type="m.room.topic",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_rooms_works_with_literals(self):
|
def test_definition_rooms_works_with_literals(self):
|
||||||
definition = {
|
definition = {"rooms": ["!secretbase:unknown"]}
|
||||||
"rooms": ["!secretbase:unknown"]
|
|
||||||
}
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
|
||||||
type="m.room.message",
|
|
||||||
room_id="!secretbase:unknown"
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertTrue(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_rooms_works_with_unknowns(self):
|
def test_definition_rooms_works_with_unknowns(self):
|
||||||
definition = {
|
definition = {"rooms": ["!secretbase:unknown"]}
|
||||||
"rooms": ["!secretbase:unknown"]
|
|
||||||
}
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
type="m.room.message",
|
type="m.room.message",
|
||||||
room_id="!anothersecretbase:unknown"
|
room_id="!anothersecretbase:unknown",
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_not_rooms_works_with_literals(self):
|
def test_definition_not_rooms_works_with_literals(self):
|
||||||
definition = {
|
definition = {"not_rooms": ["!anothersecretbase:unknown"]}
|
||||||
"not_rooms": ["!anothersecretbase:unknown"]
|
|
||||||
}
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
type="m.room.message",
|
type="m.room.message",
|
||||||
room_id="!anothersecretbase:unknown"
|
room_id="!anothersecretbase:unknown",
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_not_rooms_works_with_unknowns(self):
|
def test_definition_not_rooms_works_with_unknowns(self):
|
||||||
definition = {
|
definition = {"not_rooms": ["!secretbase:unknown"]}
|
||||||
"not_rooms": ["!secretbase:unknown"]
|
|
||||||
}
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
type="m.room.message",
|
type="m.room.message",
|
||||||
room_id="!anothersecretbase:unknown"
|
room_id="!anothersecretbase:unknown",
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertTrue(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_not_rooms_takes_priority_over_rooms(self):
|
def test_definition_not_rooms_takes_priority_over_rooms(self):
|
||||||
definition = {
|
definition = {
|
||||||
"not_rooms": ["!secretbase:unknown"],
|
"not_rooms": ["!secretbase:unknown"],
|
||||||
"rooms": ["!secretbase:unknown"]
|
"rooms": ["!secretbase:unknown"],
|
||||||
}
|
}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
|
||||||
type="m.room.message",
|
|
||||||
room_id="!secretbase:unknown"
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_combined_event(self):
|
def test_definition_combined_event(self):
|
||||||
definition = {
|
definition = {
|
||||||
@ -352,16 +254,14 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
"rooms": ["!stage:unknown"],
|
"rooms": ["!stage:unknown"],
|
||||||
"not_rooms": ["!piggyshouse:muppets"],
|
"not_rooms": ["!piggyshouse:muppets"],
|
||||||
"types": ["m.room.message", "muppets.kermit.*"],
|
"types": ["m.room.message", "muppets.kermit.*"],
|
||||||
"not_types": ["muppets.misspiggy.*"]
|
"not_types": ["muppets.misspiggy.*"],
|
||||||
}
|
}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@kermit:muppets", # yup
|
sender="@kermit:muppets", # yup
|
||||||
type="m.room.message", # yup
|
type="m.room.message", # yup
|
||||||
room_id="!stage:unknown" # yup
|
room_id="!stage:unknown", # yup
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertTrue(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_combined_event_bad_sender(self):
|
def test_definition_combined_event_bad_sender(self):
|
||||||
definition = {
|
definition = {
|
||||||
@ -370,16 +270,14 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
"rooms": ["!stage:unknown"],
|
"rooms": ["!stage:unknown"],
|
||||||
"not_rooms": ["!piggyshouse:muppets"],
|
"not_rooms": ["!piggyshouse:muppets"],
|
||||||
"types": ["m.room.message", "muppets.kermit.*"],
|
"types": ["m.room.message", "muppets.kermit.*"],
|
||||||
"not_types": ["muppets.misspiggy.*"]
|
"not_types": ["muppets.misspiggy.*"],
|
||||||
}
|
}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@misspiggy:muppets", # nope
|
sender="@misspiggy:muppets", # nope
|
||||||
type="m.room.message", # yup
|
type="m.room.message", # yup
|
||||||
room_id="!stage:unknown" # yup
|
room_id="!stage:unknown", # yup
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_combined_event_bad_room(self):
|
def test_definition_combined_event_bad_room(self):
|
||||||
definition = {
|
definition = {
|
||||||
@ -388,16 +286,14 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
"rooms": ["!stage:unknown"],
|
"rooms": ["!stage:unknown"],
|
||||||
"not_rooms": ["!piggyshouse:muppets"],
|
"not_rooms": ["!piggyshouse:muppets"],
|
||||||
"types": ["m.room.message", "muppets.kermit.*"],
|
"types": ["m.room.message", "muppets.kermit.*"],
|
||||||
"not_types": ["muppets.misspiggy.*"]
|
"not_types": ["muppets.misspiggy.*"],
|
||||||
}
|
}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@kermit:muppets", # yup
|
sender="@kermit:muppets", # yup
|
||||||
type="m.room.message", # yup
|
type="m.room.message", # yup
|
||||||
room_id="!piggyshouse:muppets" # nope
|
room_id="!piggyshouse:muppets", # nope
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
def test_definition_combined_event_bad_type(self):
|
def test_definition_combined_event_bad_type(self):
|
||||||
definition = {
|
definition = {
|
||||||
@ -406,37 +302,26 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
"rooms": ["!stage:unknown"],
|
"rooms": ["!stage:unknown"],
|
||||||
"not_rooms": ["!piggyshouse:muppets"],
|
"not_rooms": ["!piggyshouse:muppets"],
|
||||||
"types": ["m.room.message", "muppets.kermit.*"],
|
"types": ["m.room.message", "muppets.kermit.*"],
|
||||||
"not_types": ["muppets.misspiggy.*"]
|
"not_types": ["muppets.misspiggy.*"],
|
||||||
}
|
}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@kermit:muppets", # yup
|
sender="@kermit:muppets", # yup
|
||||||
type="muppets.misspiggy.kisses", # nope
|
type="muppets.misspiggy.kisses", # nope
|
||||||
room_id="!stage:unknown" # yup
|
room_id="!stage:unknown", # yup
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
Filter(definition).check(event)
|
|
||||||
)
|
)
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_filter_presence_match(self):
|
def test_filter_presence_match(self):
|
||||||
user_filter_json = {
|
user_filter_json = {"presence": {"types": ["m.*"]}}
|
||||||
"presence": {
|
|
||||||
"types": ["m.*"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
filter_id = yield self.datastore.add_user_filter(
|
filter_id = yield self.datastore.add_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart, user_filter=user_filter_json
|
||||||
user_filter=user_filter_json,
|
|
||||||
)
|
|
||||||
event = MockEvent(
|
|
||||||
sender="@foo:bar",
|
|
||||||
type="m.profile",
|
|
||||||
)
|
)
|
||||||
|
event = MockEvent(sender="@foo:bar", type="m.profile")
|
||||||
events = [event]
|
events = [event]
|
||||||
|
|
||||||
user_filter = yield self.filtering.get_user_filter(
|
user_filter = yield self.filtering.get_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart, filter_id=filter_id
|
||||||
filter_id=filter_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_presence(events=events)
|
results = user_filter.filter_presence(events=events)
|
||||||
@ -444,15 +329,10 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_filter_presence_no_match(self):
|
def test_filter_presence_no_match(self):
|
||||||
user_filter_json = {
|
user_filter_json = {"presence": {"types": ["m.*"]}}
|
||||||
"presence": {
|
|
||||||
"types": ["m.*"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
filter_id = yield self.datastore.add_user_filter(
|
filter_id = yield self.datastore.add_user_filter(
|
||||||
user_localpart=user_localpart + "2",
|
user_localpart=user_localpart + "2", user_filter=user_filter_json
|
||||||
user_filter=user_filter_json,
|
|
||||||
)
|
)
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
event_id="$asdasd:localhost",
|
event_id="$asdasd:localhost",
|
||||||
@ -462,8 +342,7 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
events = [event]
|
events = [event]
|
||||||
|
|
||||||
user_filter = yield self.filtering.get_user_filter(
|
user_filter = yield self.filtering.get_user_filter(
|
||||||
user_localpart=user_localpart + "2",
|
user_localpart=user_localpart + "2", filter_id=filter_id
|
||||||
filter_id=filter_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_presence(events=events)
|
results = user_filter.filter_presence(events=events)
|
||||||
@ -471,27 +350,15 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_filter_room_state_match(self):
|
def test_filter_room_state_match(self):
|
||||||
user_filter_json = {
|
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||||
"room": {
|
|
||||||
"state": {
|
|
||||||
"types": ["m.*"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
filter_id = yield self.datastore.add_user_filter(
|
filter_id = yield self.datastore.add_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart, user_filter=user_filter_json
|
||||||
user_filter=user_filter_json,
|
|
||||||
)
|
|
||||||
event = MockEvent(
|
|
||||||
sender="@foo:bar",
|
|
||||||
type="m.room.topic",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
)
|
||||||
|
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
||||||
events = [event]
|
events = [event]
|
||||||
|
|
||||||
user_filter = yield self.filtering.get_user_filter(
|
user_filter = yield self.filtering.get_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart, filter_id=filter_id
|
||||||
filter_id=filter_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_room_state(events=events)
|
results = user_filter.filter_room_state(events=events)
|
||||||
@ -499,27 +366,17 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_filter_room_state_no_match(self):
|
def test_filter_room_state_no_match(self):
|
||||||
user_filter_json = {
|
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||||
"room": {
|
|
||||||
"state": {
|
|
||||||
"types": ["m.*"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
filter_id = yield self.datastore.add_user_filter(
|
filter_id = yield self.datastore.add_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart, user_filter=user_filter_json
|
||||||
user_filter=user_filter_json,
|
|
||||||
)
|
)
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
|
||||||
type="org.matrix.custom.event",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
)
|
||||||
events = [event]
|
events = [event]
|
||||||
|
|
||||||
user_filter = yield self.filtering.get_user_filter(
|
user_filter = yield self.filtering.get_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart, filter_id=filter_id
|
||||||
filter_id=filter_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_room_state(events)
|
results = user_filter.filter_room_state(events)
|
||||||
@ -543,45 +400,32 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_add_filter(self):
|
def test_add_filter(self):
|
||||||
user_filter_json = {
|
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||||
"room": {
|
|
||||||
"state": {
|
|
||||||
"types": ["m.*"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
filter_id = yield self.filtering.add_user_filter(
|
filter_id = yield self.filtering.add_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart, user_filter=user_filter_json
|
||||||
user_filter=user_filter_json,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(filter_id, 0)
|
self.assertEquals(filter_id, 0)
|
||||||
self.assertEquals(user_filter_json, (
|
self.assertEquals(
|
||||||
yield self.datastore.get_user_filter(
|
user_filter_json,
|
||||||
user_localpart=user_localpart,
|
(
|
||||||
filter_id=0,
|
yield self.datastore.get_user_filter(
|
||||||
)
|
user_localpart=user_localpart, filter_id=0
|
||||||
))
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_filter(self):
|
def test_get_filter(self):
|
||||||
user_filter_json = {
|
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||||
"room": {
|
|
||||||
"state": {
|
|
||||||
"types": ["m.*"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
filter_id = yield self.datastore.add_user_filter(
|
filter_id = yield self.datastore.add_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart, user_filter=user_filter_json
|
||||||
user_filter=user_filter_json,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
filter = yield self.filtering.get_user_filter(
|
filter = yield self.filtering.get_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart, filter_id=filter_id
|
||||||
filter_id=filter_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(filter.get_filter_json(), user_filter_json)
|
self.assertEquals(filter.get_filter_json(), user_filter_json)
|
||||||
|
@ -4,17 +4,16 @@ from tests import unittest
|
|||||||
|
|
||||||
|
|
||||||
class TestRatelimiter(unittest.TestCase):
|
class TestRatelimiter(unittest.TestCase):
|
||||||
|
|
||||||
def test_allowed(self):
|
def test_allowed(self):
|
||||||
limiter = Ratelimiter()
|
limiter = Ratelimiter()
|
||||||
allowed, time_allowed = limiter.send_message(
|
allowed, time_allowed = limiter.send_message(
|
||||||
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
|
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(10., time_allowed)
|
self.assertEquals(10., time_allowed)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.send_message(
|
allowed, time_allowed = limiter.send_message(
|
||||||
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1,
|
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1
|
||||||
)
|
)
|
||||||
self.assertFalse(allowed)
|
self.assertFalse(allowed)
|
||||||
self.assertEquals(10., time_allowed)
|
self.assertEquals(10., time_allowed)
|
||||||
@ -28,7 +27,7 @@ class TestRatelimiter(unittest.TestCase):
|
|||||||
def test_pruning(self):
|
def test_pruning(self):
|
||||||
limiter = Ratelimiter()
|
limiter = Ratelimiter()
|
||||||
allowed, time_allowed = limiter.send_message(
|
allowed, time_allowed = limiter.send_message(
|
||||||
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
|
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIn("test_id_1", limiter.message_counts)
|
self.assertIn("test_id_1", limiter.message_counts)
|
||||||
|
@ -24,14 +24,10 @@ from tests import unittest
|
|||||||
|
|
||||||
|
|
||||||
def _regex(regex, exclusive=True):
|
def _regex(regex, exclusive=True):
|
||||||
return {
|
return {"regex": re.compile(regex), "exclusive": exclusive}
|
||||||
"regex": re.compile(regex),
|
|
||||||
"exclusive": exclusive
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceTestCase(unittest.TestCase):
|
class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.service = ApplicationService(
|
self.service = ApplicationService(
|
||||||
id="unique_identifier",
|
id="unique_identifier",
|
||||||
@ -41,8 +37,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||||||
namespaces={
|
namespaces={
|
||||||
ApplicationService.NS_USERS: [],
|
ApplicationService.NS_USERS: [],
|
||||||
ApplicationService.NS_ROOMS: [],
|
ApplicationService.NS_ROOMS: [],
|
||||||
ApplicationService.NS_ALIASES: []
|
ApplicationService.NS_ALIASES: [],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
self.event = Mock(
|
self.event = Mock(
|
||||||
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
|
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
|
||||||
@ -52,25 +48,19 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_user_id_prefix_match(self):
|
def test_regex_user_id_prefix_match(self):
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
_regex("@irc_.*")
|
|
||||||
)
|
|
||||||
self.event.sender = "@irc_foobar:matrix.org"
|
self.event.sender = "@irc_foobar:matrix.org"
|
||||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_user_id_prefix_no_match(self):
|
def test_regex_user_id_prefix_no_match(self):
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
_regex("@irc_.*")
|
|
||||||
)
|
|
||||||
self.event.sender = "@someone_else:matrix.org"
|
self.event.sender = "@someone_else:matrix.org"
|
||||||
self.assertFalse((yield self.service.is_interested(self.event)))
|
self.assertFalse((yield self.service.is_interested(self.event)))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_room_member_is_checked(self):
|
def test_regex_room_member_is_checked(self):
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
_regex("@irc_.*")
|
|
||||||
)
|
|
||||||
self.event.sender = "@someone_else:matrix.org"
|
self.event.sender = "@someone_else:matrix.org"
|
||||||
self.event.type = "m.room.member"
|
self.event.type = "m.room.member"
|
||||||
self.event.state_key = "@irc_foobar:matrix.org"
|
self.event.state_key = "@irc_foobar:matrix.org"
|
||||||
@ -98,60 +88,47 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||||||
_regex("#irc_.*:matrix.org")
|
_regex("#irc_.*:matrix.org")
|
||||||
)
|
)
|
||||||
self.store.get_aliases_for_room.return_value = [
|
self.store.get_aliases_for_room.return_value = [
|
||||||
"#irc_foobar:matrix.org", "#athing:matrix.org"
|
"#irc_foobar:matrix.org",
|
||||||
|
"#athing:matrix.org",
|
||||||
]
|
]
|
||||||
self.store.get_users_in_room.return_value = []
|
self.store.get_users_in_room.return_value = []
|
||||||
self.assertTrue((yield self.service.is_interested(
|
self.assertTrue((yield self.service.is_interested(self.event, self.store)))
|
||||||
self.event, self.store
|
|
||||||
)))
|
|
||||||
|
|
||||||
def test_non_exclusive_alias(self):
|
def test_non_exclusive_alias(self):
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
_regex("#irc_.*:matrix.org", exclusive=False)
|
_regex("#irc_.*:matrix.org", exclusive=False)
|
||||||
)
|
)
|
||||||
self.assertFalse(self.service.is_exclusive_alias(
|
self.assertFalse(self.service.is_exclusive_alias("#irc_foobar:matrix.org"))
|
||||||
"#irc_foobar:matrix.org"
|
|
||||||
))
|
|
||||||
|
|
||||||
def test_non_exclusive_room(self):
|
def test_non_exclusive_room(self):
|
||||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||||
_regex("!irc_.*:matrix.org", exclusive=False)
|
_regex("!irc_.*:matrix.org", exclusive=False)
|
||||||
)
|
)
|
||||||
self.assertFalse(self.service.is_exclusive_room(
|
self.assertFalse(self.service.is_exclusive_room("!irc_foobar:matrix.org"))
|
||||||
"!irc_foobar:matrix.org"
|
|
||||||
))
|
|
||||||
|
|
||||||
def test_non_exclusive_user(self):
|
def test_non_exclusive_user(self):
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||||
_regex("@irc_.*:matrix.org", exclusive=False)
|
_regex("@irc_.*:matrix.org", exclusive=False)
|
||||||
)
|
)
|
||||||
self.assertFalse(self.service.is_exclusive_user(
|
self.assertFalse(self.service.is_exclusive_user("@irc_foobar:matrix.org"))
|
||||||
"@irc_foobar:matrix.org"
|
|
||||||
))
|
|
||||||
|
|
||||||
def test_exclusive_alias(self):
|
def test_exclusive_alias(self):
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
_regex("#irc_.*:matrix.org", exclusive=True)
|
_regex("#irc_.*:matrix.org", exclusive=True)
|
||||||
)
|
)
|
||||||
self.assertTrue(self.service.is_exclusive_alias(
|
self.assertTrue(self.service.is_exclusive_alias("#irc_foobar:matrix.org"))
|
||||||
"#irc_foobar:matrix.org"
|
|
||||||
))
|
|
||||||
|
|
||||||
def test_exclusive_user(self):
|
def test_exclusive_user(self):
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||||
_regex("@irc_.*:matrix.org", exclusive=True)
|
_regex("@irc_.*:matrix.org", exclusive=True)
|
||||||
)
|
)
|
||||||
self.assertTrue(self.service.is_exclusive_user(
|
self.assertTrue(self.service.is_exclusive_user("@irc_foobar:matrix.org"))
|
||||||
"@irc_foobar:matrix.org"
|
|
||||||
))
|
|
||||||
|
|
||||||
def test_exclusive_room(self):
|
def test_exclusive_room(self):
|
||||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||||
_regex("!irc_.*:matrix.org", exclusive=True)
|
_regex("!irc_.*:matrix.org", exclusive=True)
|
||||||
)
|
)
|
||||||
self.assertTrue(self.service.is_exclusive_room(
|
self.assertTrue(self.service.is_exclusive_room("!irc_foobar:matrix.org"))
|
||||||
"!irc_foobar:matrix.org"
|
|
||||||
))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_alias_no_match(self):
|
def test_regex_alias_no_match(self):
|
||||||
@ -159,47 +136,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||||||
_regex("#irc_.*:matrix.org")
|
_regex("#irc_.*:matrix.org")
|
||||||
)
|
)
|
||||||
self.store.get_aliases_for_room.return_value = [
|
self.store.get_aliases_for_room.return_value = [
|
||||||
"#xmpp_foobar:matrix.org", "#athing:matrix.org"
|
"#xmpp_foobar:matrix.org",
|
||||||
|
"#athing:matrix.org",
|
||||||
]
|
]
|
||||||
self.store.get_users_in_room.return_value = []
|
self.store.get_users_in_room.return_value = []
|
||||||
self.assertFalse((yield self.service.is_interested(
|
self.assertFalse((yield self.service.is_interested(self.event, self.store)))
|
||||||
self.event, self.store
|
|
||||||
)))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_multiple_matches(self):
|
def test_regex_multiple_matches(self):
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
_regex("#irc_.*:matrix.org")
|
_regex("#irc_.*:matrix.org")
|
||||||
)
|
)
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
_regex("@irc_.*")
|
|
||||||
)
|
|
||||||
self.event.sender = "@irc_foobar:matrix.org"
|
self.event.sender = "@irc_foobar:matrix.org"
|
||||||
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
|
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
|
||||||
self.store.get_users_in_room.return_value = []
|
self.store.get_users_in_room.return_value = []
|
||||||
self.assertTrue((yield self.service.is_interested(
|
self.assertTrue((yield self.service.is_interested(self.event, self.store)))
|
||||||
self.event, self.store
|
|
||||||
)))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_interested_in_self(self):
|
def test_interested_in_self(self):
|
||||||
# make sure invites get through
|
# make sure invites get through
|
||||||
self.service.sender = "@appservice:name"
|
self.service.sender = "@appservice:name"
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
_regex("@irc_.*")
|
|
||||||
)
|
|
||||||
self.event.type = "m.room.member"
|
self.event.type = "m.room.member"
|
||||||
self.event.content = {
|
self.event.content = {"membership": "invite"}
|
||||||
"membership": "invite"
|
|
||||||
}
|
|
||||||
self.event.state_key = self.service.sender
|
self.event.state_key = self.service.sender
|
||||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_member_list_match(self):
|
def test_member_list_match(self):
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
_regex("@irc_.*")
|
|
||||||
)
|
|
||||||
self.store.get_users_in_room.return_value = [
|
self.store.get_users_in_room.return_value = [
|
||||||
"@alice:here",
|
"@alice:here",
|
||||||
"@irc_fo:here", # AS user
|
"@irc_fo:here", # AS user
|
||||||
@ -208,6 +174,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||||||
self.store.get_aliases_for_room.return_value = []
|
self.store.get_aliases_for_room.return_value = []
|
||||||
|
|
||||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
self.event.sender = "@xmpp_foobar:matrix.org"
|
||||||
self.assertTrue((yield self.service.is_interested(
|
self.assertTrue(
|
||||||
event=self.event, store=self.store
|
(yield self.service.is_interested(event=self.event, store=self.store))
|
||||||
)))
|
)
|
||||||
|
@ -30,7 +30,6 @@ from ..utils import MockClock
|
|||||||
|
|
||||||
|
|
||||||
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.clock = MockClock()
|
self.clock = MockClock()
|
||||||
self.store = Mock()
|
self.store = Mock()
|
||||||
@ -38,8 +37,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||||||
self.recoverer = Mock()
|
self.recoverer = Mock()
|
||||||
self.recoverer_fn = Mock(return_value=self.recoverer)
|
self.recoverer_fn = Mock(return_value=self.recoverer)
|
||||||
self.txnctrl = _TransactionController(
|
self.txnctrl = _TransactionController(
|
||||||
clock=self.clock, store=self.store, as_api=self.as_api,
|
clock=self.clock,
|
||||||
recoverer_fn=self.recoverer_fn
|
store=self.store,
|
||||||
|
as_api=self.as_api,
|
||||||
|
recoverer_fn=self.recoverer_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_single_service_up_txn_sent(self):
|
def test_single_service_up_txn_sent(self):
|
||||||
@ -54,9 +55,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||||||
return_value=defer.succeed(ApplicationServiceState.UP)
|
return_value=defer.succeed(ApplicationServiceState.UP)
|
||||||
)
|
)
|
||||||
txn.send = Mock(return_value=defer.succeed(True))
|
txn.send = Mock(return_value=defer.succeed(True))
|
||||||
self.store.create_appservice_txn = Mock(
|
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
|
||||||
return_value=defer.succeed(txn)
|
|
||||||
)
|
|
||||||
|
|
||||||
# actual call
|
# actual call
|
||||||
self.txnctrl.send(service, events)
|
self.txnctrl.send(service, events)
|
||||||
@ -77,9 +76,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||||||
self.store.get_appservice_state = Mock(
|
self.store.get_appservice_state = Mock(
|
||||||
return_value=defer.succeed(ApplicationServiceState.DOWN)
|
return_value=defer.succeed(ApplicationServiceState.DOWN)
|
||||||
)
|
)
|
||||||
self.store.create_appservice_txn = Mock(
|
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
|
||||||
return_value=defer.succeed(txn)
|
|
||||||
)
|
|
||||||
|
|
||||||
# actual call
|
# actual call
|
||||||
self.txnctrl.send(service, events)
|
self.txnctrl.send(service, events)
|
||||||
@ -104,9 +101,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
|
self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
|
||||||
txn.send = Mock(return_value=defer.succeed(False)) # fails to send
|
txn.send = Mock(return_value=defer.succeed(False)) # fails to send
|
||||||
self.store.create_appservice_txn = Mock(
|
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
|
||||||
return_value=defer.succeed(txn)
|
|
||||||
)
|
|
||||||
|
|
||||||
# actual call
|
# actual call
|
||||||
self.txnctrl.send(service, events)
|
self.txnctrl.send(service, events)
|
||||||
@ -124,7 +119,6 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.clock = MockClock()
|
self.clock = MockClock()
|
||||||
self.as_api = Mock()
|
self.as_api = Mock()
|
||||||
@ -146,6 +140,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def take_txn(*args, **kwargs):
|
def take_txn(*args, **kwargs):
|
||||||
return defer.succeed(txns.pop(0))
|
return defer.succeed(txns.pop(0))
|
||||||
|
|
||||||
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
|
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
|
||||||
|
|
||||||
self.recoverer.recover()
|
self.recoverer.recover()
|
||||||
@ -171,6 +166,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
|||||||
return defer.succeed(txns.pop(0))
|
return defer.succeed(txns.pop(0))
|
||||||
else:
|
else:
|
||||||
return defer.succeed(txn)
|
return defer.succeed(txn)
|
||||||
|
|
||||||
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
|
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
|
||||||
|
|
||||||
self.recoverer.recover()
|
self.recoverer.recover()
|
||||||
@ -197,7 +193,6 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.txn_ctrl = Mock()
|
self.txn_ctrl = Mock()
|
||||||
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
|
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
|
||||||
@ -211,9 +206,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_send_single_event_with_queue(self):
|
def test_send_single_event_with_queue(self):
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
self.txn_ctrl.send = Mock(
|
self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d))
|
||||||
side_effect=lambda x, y: make_deferred_yieldable(d),
|
|
||||||
)
|
|
||||||
service = Mock(id=4)
|
service = Mock(id=4)
|
||||||
event = Mock(event_id="first")
|
event = Mock(event_id="first")
|
||||||
event2 = Mock(event_id="second")
|
event2 = Mock(event_id="second")
|
||||||
@ -247,6 +240,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def do_send(x, y):
|
def do_send(x, y):
|
||||||
return make_deferred_yieldable(send_return_list.pop(0))
|
return make_deferred_yieldable(send_return_list.pop(0))
|
||||||
|
|
||||||
self.txn_ctrl.send = Mock(side_effect=do_send)
|
self.txn_ctrl.send = Mock(side_effect=do_send)
|
||||||
|
|
||||||
# send events for different ASes and make sure they are sent
|
# send events for different ASes and make sure they are sent
|
||||||
|
@ -24,7 +24,6 @@ from tests import unittest
|
|||||||
|
|
||||||
|
|
||||||
class ConfigGenerationTestCase(unittest.TestCase):
|
class ConfigGenerationTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.dir = tempfile.mkdtemp()
|
self.dir = tempfile.mkdtemp()
|
||||||
self.file = os.path.join(self.dir, "homeserver.yaml")
|
self.file = os.path.join(self.dir, "homeserver.yaml")
|
||||||
@ -33,23 +32,30 @@ class ConfigGenerationTestCase(unittest.TestCase):
|
|||||||
shutil.rmtree(self.dir)
|
shutil.rmtree(self.dir)
|
||||||
|
|
||||||
def test_generate_config_generates_files(self):
|
def test_generate_config_generates_files(self):
|
||||||
HomeServerConfig.load_or_generate_config("", [
|
HomeServerConfig.load_or_generate_config(
|
||||||
"--generate-config",
|
"",
|
||||||
"-c", self.file,
|
[
|
||||||
"--report-stats=yes",
|
"--generate-config",
|
||||||
"-H", "lemurs.win"
|
"-c",
|
||||||
])
|
self.file,
|
||||||
|
"--report-stats=yes",
|
||||||
|
"-H",
|
||||||
|
"lemurs.win",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
set([
|
set(
|
||||||
"homeserver.yaml",
|
[
|
||||||
"lemurs.win.log.config",
|
"homeserver.yaml",
|
||||||
"lemurs.win.signing.key",
|
"lemurs.win.log.config",
|
||||||
"lemurs.win.tls.crt",
|
"lemurs.win.signing.key",
|
||||||
"lemurs.win.tls.dh",
|
"lemurs.win.tls.crt",
|
||||||
"lemurs.win.tls.key",
|
"lemurs.win.tls.dh",
|
||||||
]),
|
"lemurs.win.tls.key",
|
||||||
set(os.listdir(self.dir))
|
]
|
||||||
|
),
|
||||||
|
set(os.listdir(self.dir)),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assert_log_filename_is(
|
self.assert_log_filename_is(
|
||||||
|
@ -24,7 +24,6 @@ from tests import unittest
|
|||||||
|
|
||||||
|
|
||||||
class ConfigLoadingTestCase(unittest.TestCase):
|
class ConfigLoadingTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.dir = tempfile.mkdtemp()
|
self.dir = tempfile.mkdtemp()
|
||||||
print(self.dir)
|
print(self.dir)
|
||||||
@ -43,15 +42,14 @@ class ConfigLoadingTestCase(unittest.TestCase):
|
|||||||
def test_generates_and_loads_macaroon_secret_key(self):
|
def test_generates_and_loads_macaroon_secret_key(self):
|
||||||
self.generate_config()
|
self.generate_config()
|
||||||
|
|
||||||
with open(self.file,
|
with open(self.file, "r") as f:
|
||||||
"r") as f:
|
|
||||||
raw = yaml.load(f)
|
raw = yaml.load(f)
|
||||||
self.assertIn("macaroon_secret_key", raw)
|
self.assertIn("macaroon_secret_key", raw)
|
||||||
|
|
||||||
config = HomeServerConfig.load_config("", ["-c", self.file])
|
config = HomeServerConfig.load_config("", ["-c", self.file])
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
hasattr(config, "macaroon_secret_key"),
|
hasattr(config, "macaroon_secret_key"),
|
||||||
"Want config to have attr macaroon_secret_key"
|
"Want config to have attr macaroon_secret_key",
|
||||||
)
|
)
|
||||||
if len(config.macaroon_secret_key) < 5:
|
if len(config.macaroon_secret_key) < 5:
|
||||||
self.fail(
|
self.fail(
|
||||||
@ -62,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
|
|||||||
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
|
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
hasattr(config, "macaroon_secret_key"),
|
hasattr(config, "macaroon_secret_key"),
|
||||||
"Want config to have attr macaroon_secret_key"
|
"Want config to have attr macaroon_secret_key",
|
||||||
)
|
)
|
||||||
if len(config.macaroon_secret_key) < 5:
|
if len(config.macaroon_secret_key) < 5:
|
||||||
self.fail(
|
self.fail(
|
||||||
@ -80,10 +78,9 @@ class ConfigLoadingTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_disable_registration(self):
|
def test_disable_registration(self):
|
||||||
self.generate_config()
|
self.generate_config()
|
||||||
self.add_lines_to_config([
|
self.add_lines_to_config(
|
||||||
"enable_registration: true",
|
["enable_registration: true", "disable_registration: true"]
|
||||||
"disable_registration: true",
|
)
|
||||||
])
|
|
||||||
# Check that disable_registration clobbers enable_registration.
|
# Check that disable_registration clobbers enable_registration.
|
||||||
config = HomeServerConfig.load_config("", ["-c", self.file])
|
config = HomeServerConfig.load_config("", ["-c", self.file])
|
||||||
self.assertFalse(config.enable_registration)
|
self.assertFalse(config.enable_registration)
|
||||||
@ -92,18 +89,23 @@ class ConfigLoadingTestCase(unittest.TestCase):
|
|||||||
self.assertFalse(config.enable_registration)
|
self.assertFalse(config.enable_registration)
|
||||||
|
|
||||||
# Check that either config value is clobbered by the command line.
|
# Check that either config value is clobbered by the command line.
|
||||||
config = HomeServerConfig.load_or_generate_config("", [
|
config = HomeServerConfig.load_or_generate_config(
|
||||||
"-c", self.file, "--enable-registration"
|
"", ["-c", self.file, "--enable-registration"]
|
||||||
])
|
)
|
||||||
self.assertTrue(config.enable_registration)
|
self.assertTrue(config.enable_registration)
|
||||||
|
|
||||||
def generate_config(self):
|
def generate_config(self):
|
||||||
HomeServerConfig.load_or_generate_config("", [
|
HomeServerConfig.load_or_generate_config(
|
||||||
"--generate-config",
|
"",
|
||||||
"-c", self.file,
|
[
|
||||||
"--report-stats=yes",
|
"--generate-config",
|
||||||
"-H", "lemurs.win"
|
"-c",
|
||||||
])
|
self.file,
|
||||||
|
"--report-stats=yes",
|
||||||
|
"-H",
|
||||||
|
"lemurs.win",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def generate_config_and_remove_lines_containing(self, needle):
|
def generate_config_and_remove_lines_containing(self, needle):
|
||||||
self.generate_config()
|
self.generate_config()
|
||||||
|
@ -24,9 +24,7 @@ from tests import unittest
|
|||||||
|
|
||||||
# Perform these tests using given secret key so we get entirely deterministic
|
# Perform these tests using given secret key so we get entirely deterministic
|
||||||
# signatures output that we can test against.
|
# signatures output that we can test against.
|
||||||
SIGNING_KEY_SEED = decode_base64(
|
SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
|
||||||
"YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1"
|
|
||||||
)
|
|
||||||
|
|
||||||
KEY_ALG = "ed25519"
|
KEY_ALG = "ed25519"
|
||||||
KEY_VER = 1
|
KEY_VER = 1
|
||||||
@ -36,7 +34,6 @@ HOSTNAME = "domain"
|
|||||||
|
|
||||||
|
|
||||||
class EventSigningTestCase(unittest.TestCase):
|
class EventSigningTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED)
|
self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED)
|
||||||
self.signing_key.alg = KEY_ALG
|
self.signing_key.alg = KEY_ALG
|
||||||
@ -51,7 +48,7 @@ class EventSigningTestCase(unittest.TestCase):
|
|||||||
'signatures': {},
|
'signatures': {},
|
||||||
'type': "X",
|
'type': "X",
|
||||||
'unsigned': {'age_ts': 1000000},
|
'unsigned': {'age_ts': 1000000},
|
||||||
},
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
|
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
|
||||||
@ -61,8 +58,7 @@ class EventSigningTestCase(unittest.TestCase):
|
|||||||
self.assertTrue(hasattr(event, 'hashes'))
|
self.assertTrue(hasattr(event, 'hashes'))
|
||||||
self.assertIn('sha256', event.hashes)
|
self.assertIn('sha256', event.hashes)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
event.hashes['sha256'],
|
event.hashes['sha256'], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI"
|
||||||
"6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(hasattr(event, 'signatures'))
|
self.assertTrue(hasattr(event, 'signatures'))
|
||||||
@ -77,9 +73,7 @@ class EventSigningTestCase(unittest.TestCase):
|
|||||||
def test_sign_message(self):
|
def test_sign_message(self):
|
||||||
builder = EventBuilder(
|
builder = EventBuilder(
|
||||||
{
|
{
|
||||||
'content': {
|
'content': {'body': "Here is the message content"},
|
||||||
'body': "Here is the message content",
|
|
||||||
},
|
|
||||||
'event_id': "$0:domain",
|
'event_id': "$0:domain",
|
||||||
'origin': "domain",
|
'origin': "domain",
|
||||||
'origin_server_ts': 1000000,
|
'origin_server_ts': 1000000,
|
||||||
@ -98,8 +92,7 @@ class EventSigningTestCase(unittest.TestCase):
|
|||||||
self.assertTrue(hasattr(event, 'hashes'))
|
self.assertTrue(hasattr(event, 'hashes'))
|
||||||
self.assertIn('sha256', event.hashes)
|
self.assertIn('sha256', event.hashes)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
event.hashes['sha256'],
|
event.hashes['sha256'], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g"
|
||||||
"onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(hasattr(event, 'signatures'))
|
self.assertTrue(hasattr(event, 'signatures'))
|
||||||
@ -108,5 +101,5 @@ class EventSigningTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
event.signatures[HOSTNAME][KEY_NAME],
|
event.signatures[HOSTNAME][KEY_NAME],
|
||||||
"Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw"
|
"Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw"
|
||||||
"u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA"
|
"u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA",
|
||||||
)
|
)
|
||||||
|
@ -36,9 +36,7 @@ class MockPerspectiveServer(object):
|
|||||||
|
|
||||||
def get_verify_keys(self):
|
def get_verify_keys(self):
|
||||||
vk = signedjson.key.get_verify_key(self.key)
|
vk = signedjson.key.get_verify_key(self.key)
|
||||||
return {
|
return {"%s:%s" % (vk.alg, vk.version): vk}
|
||||||
"%s:%s" % (vk.alg, vk.version): vk,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_signed_key(self, server_name, verify_key):
|
def get_signed_key(self, server_name, verify_key):
|
||||||
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
|
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
|
||||||
@ -47,10 +45,8 @@ class MockPerspectiveServer(object):
|
|||||||
"old_verify_keys": {},
|
"old_verify_keys": {},
|
||||||
"valid_until_ts": time.time() * 1000 + 3600,
|
"valid_until_ts": time.time() * 1000 + 3600,
|
||||||
"verify_keys": {
|
"verify_keys": {
|
||||||
key_id: {
|
key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
|
||||||
"key": signedjson.key.encode_verify_key_base64(verify_key)
|
},
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
signedjson.sign.sign_json(res, self.server_name, self.key)
|
signedjson.sign.sign_json(res, self.server_name, self.key)
|
||||||
return res
|
return res
|
||||||
@ -62,18 +58,14 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
self.mock_perspective_server = MockPerspectiveServer()
|
self.mock_perspective_server = MockPerspectiveServer()
|
||||||
self.http_client = Mock()
|
self.http_client = Mock()
|
||||||
self.hs = yield utils.setup_test_homeserver(
|
self.hs = yield utils.setup_test_homeserver(
|
||||||
handlers=None,
|
self.addCleanup, handlers=None, http_client=self.http_client
|
||||||
http_client=self.http_client,
|
|
||||||
)
|
)
|
||||||
self.hs.config.perspectives = {
|
keys = self.mock_perspective_server.get_verify_keys()
|
||||||
self.mock_perspective_server.server_name:
|
self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
|
||||||
self.mock_perspective_server.get_verify_keys()
|
|
||||||
}
|
|
||||||
|
|
||||||
def check_context(self, _, expected):
|
def check_context(self, _, expected):
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
getattr(LoggingContext.current_context(), "request", None),
|
getattr(LoggingContext.current_context(), "request", None), expected
|
||||||
expected
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -89,8 +81,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
context_one.request = "one"
|
context_one.request = "one"
|
||||||
|
|
||||||
wait_1_deferred = kr.wait_for_previous_lookups(
|
wait_1_deferred = kr.wait_for_previous_lookups(
|
||||||
["server1"],
|
["server1"], {"server1": lookup_1_deferred}
|
||||||
{"server1": lookup_1_deferred},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# there were no previous lookups, so the deferred should be ready
|
# there were no previous lookups, so the deferred should be ready
|
||||||
@ -105,8 +96,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
# set off another wait. It should block because the first lookup
|
# set off another wait. It should block because the first lookup
|
||||||
# hasn't yet completed.
|
# hasn't yet completed.
|
||||||
wait_2_deferred = kr.wait_for_previous_lookups(
|
wait_2_deferred = kr.wait_for_previous_lookups(
|
||||||
["server1"],
|
["server1"], {"server1": lookup_2_deferred}
|
||||||
{"server1": lookup_2_deferred},
|
|
||||||
)
|
)
|
||||||
self.assertFalse(wait_2_deferred.called)
|
self.assertFalse(wait_2_deferred.called)
|
||||||
# ... so we should have reset the LoggingContext.
|
# ... so we should have reset the LoggingContext.
|
||||||
@ -132,21 +122,19 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
persp_resp = {
|
persp_resp = {
|
||||||
"server_keys": [
|
"server_keys": [
|
||||||
self.mock_perspective_server.get_signed_key(
|
self.mock_perspective_server.get_signed_key(
|
||||||
"server10",
|
"server10", signedjson.key.get_verify_key(key1)
|
||||||
signedjson.key.get_verify_key(key1)
|
)
|
||||||
),
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
persp_deferred = defer.Deferred()
|
persp_deferred = defer.Deferred()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_perspectives(**kwargs):
|
def get_perspectives(**kwargs):
|
||||||
self.assertEquals(
|
self.assertEquals(LoggingContext.current_context().request, "11")
|
||||||
LoggingContext.current_context().request, "11",
|
|
||||||
)
|
|
||||||
with logcontext.PreserveLoggingContext():
|
with logcontext.PreserveLoggingContext():
|
||||||
yield persp_deferred
|
yield persp_deferred
|
||||||
defer.returnValue(persp_resp)
|
defer.returnValue(persp_resp)
|
||||||
|
|
||||||
self.http_client.post_json.side_effect = get_perspectives
|
self.http_client.post_json.side_effect = get_perspectives
|
||||||
|
|
||||||
with LoggingContext("11") as context_11:
|
with LoggingContext("11") as context_11:
|
||||||
@ -154,9 +142,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
# start off a first set of lookups
|
# start off a first set of lookups
|
||||||
res_deferreds = kr.verify_json_objects_for_server(
|
res_deferreds = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1),
|
[("server10", json1), ("server11", {})]
|
||||||
("server11", {})
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# the unsigned json should be rejected pretty quickly
|
# the unsigned json should be rejected pretty quickly
|
||||||
@ -172,7 +158,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
# wait a tick for it to send the request to the perspectives server
|
# wait a tick for it to send the request to the perspectives server
|
||||||
# (it first tries the datastore)
|
# (it first tries the datastore)
|
||||||
yield clock.sleep(1) # XXX find out why this takes so long!
|
yield clock.sleep(1) # XXX find out why this takes so long!
|
||||||
self.http_client.post_json.assert_called_once()
|
self.http_client.post_json.assert_called_once()
|
||||||
|
|
||||||
self.assertIs(LoggingContext.current_context(), context_11)
|
self.assertIs(LoggingContext.current_context(), context_11)
|
||||||
@ -186,7 +172,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
self.http_client.post_json.return_value = defer.Deferred()
|
self.http_client.post_json.return_value = defer.Deferred()
|
||||||
|
|
||||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1)],
|
[("server10", json1)]
|
||||||
)
|
)
|
||||||
yield clock.sleep(1)
|
yield clock.sleep(1)
|
||||||
self.http_client.post_json.assert_not_called()
|
self.http_client.post_json.assert_not_called()
|
||||||
@ -207,8 +193,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
key1 = signedjson.key.generate_signing_key(1)
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
yield self.hs.datastore.store_server_verify_key(
|
yield self.hs.datastore.store_server_verify_key(
|
||||||
"server9", "", time.time() * 1000,
|
"server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
|
||||||
signedjson.key.get_verify_key(key1),
|
|
||||||
)
|
)
|
||||||
json1 = {}
|
json1 = {}
|
||||||
signedjson.sign.sign_json(json1, "server9", key1)
|
signedjson.sign.sign_json(json1, "server9", key1)
|
||||||
|
@ -31,25 +31,20 @@ def MockEvent(**kwargs):
|
|||||||
class PruneEventTestCase(unittest.TestCase):
|
class PruneEventTestCase(unittest.TestCase):
|
||||||
""" Asserts that a new event constructed with `evdict` will look like
|
""" Asserts that a new event constructed with `evdict` will look like
|
||||||
`matchdict` when it is redacted. """
|
`matchdict` when it is redacted. """
|
||||||
|
|
||||||
def run_test(self, evdict, matchdict):
|
def run_test(self, evdict, matchdict):
|
||||||
self.assertEquals(
|
self.assertEquals(prune_event(FrozenEvent(evdict)).get_dict(), matchdict)
|
||||||
prune_event(FrozenEvent(evdict)).get_dict(),
|
|
||||||
matchdict
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_minimal(self):
|
def test_minimal(self):
|
||||||
self.run_test(
|
self.run_test(
|
||||||
{
|
{'type': 'A', 'event_id': '$test:domain'},
|
||||||
'type': 'A',
|
|
||||||
'event_id': '$test:domain',
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
'type': 'A',
|
'type': 'A',
|
||||||
'event_id': '$test:domain',
|
'event_id': '$test:domain',
|
||||||
'content': {},
|
'content': {},
|
||||||
'signatures': {},
|
'signatures': {},
|
||||||
'unsigned': {},
|
'unsigned': {},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_basic_keys(self):
|
def test_basic_keys(self):
|
||||||
@ -70,23 +65,19 @@ class PruneEventTestCase(unittest.TestCase):
|
|||||||
'content': {},
|
'content': {},
|
||||||
'signatures': {},
|
'signatures': {},
|
||||||
'unsigned': {},
|
'unsigned': {},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_unsigned_age_ts(self):
|
def test_unsigned_age_ts(self):
|
||||||
self.run_test(
|
self.run_test(
|
||||||
{
|
{'type': 'B', 'event_id': '$test:domain', 'unsigned': {'age_ts': 20}},
|
||||||
'type': 'B',
|
|
||||||
'event_id': '$test:domain',
|
|
||||||
'unsigned': {'age_ts': 20},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
'type': 'B',
|
'type': 'B',
|
||||||
'event_id': '$test:domain',
|
'event_id': '$test:domain',
|
||||||
'content': {},
|
'content': {},
|
||||||
'signatures': {},
|
'signatures': {},
|
||||||
'unsigned': {'age_ts': 20},
|
'unsigned': {'age_ts': 20},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.run_test(
|
self.run_test(
|
||||||
@ -101,23 +92,19 @@ class PruneEventTestCase(unittest.TestCase):
|
|||||||
'content': {},
|
'content': {},
|
||||||
'signatures': {},
|
'signatures': {},
|
||||||
'unsigned': {},
|
'unsigned': {},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_content(self):
|
def test_content(self):
|
||||||
self.run_test(
|
self.run_test(
|
||||||
{
|
{'type': 'C', 'event_id': '$test:domain', 'content': {'things': 'here'}},
|
||||||
'type': 'C',
|
|
||||||
'event_id': '$test:domain',
|
|
||||||
'content': {'things': 'here'},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
'type': 'C',
|
'type': 'C',
|
||||||
'event_id': '$test:domain',
|
'event_id': '$test:domain',
|
||||||
'content': {},
|
'content': {},
|
||||||
'signatures': {},
|
'signatures': {},
|
||||||
'unsigned': {},
|
'unsigned': {},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.run_test(
|
self.run_test(
|
||||||
@ -132,27 +119,20 @@ class PruneEventTestCase(unittest.TestCase):
|
|||||||
'content': {'creator': '@2:domain'},
|
'content': {'creator': '@2:domain'},
|
||||||
'signatures': {},
|
'signatures': {},
|
||||||
'unsigned': {},
|
'unsigned': {},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SerializeEventTestCase(unittest.TestCase):
|
class SerializeEventTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def serialize(self, ev, fields):
|
def serialize(self, ev, fields):
|
||||||
return serialize_event(ev, 1479807801915, only_event_fields=fields)
|
return serialize_event(ev, 1479807801915, only_event_fields=fields)
|
||||||
|
|
||||||
def test_event_fields_works_with_keys(self):
|
def test_event_fields_works_with_keys(self):
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
self.serialize(
|
self.serialize(
|
||||||
MockEvent(
|
MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"]
|
||||||
sender="@alice:localhost",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
),
|
|
||||||
["room_id"]
|
|
||||||
),
|
),
|
||||||
{
|
{"room_id": "!foo:bar"},
|
||||||
"room_id": "!foo:bar",
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_event_fields_works_with_nested_keys(self):
|
def test_event_fields_works_with_nested_keys(self):
|
||||||
@ -161,17 +141,11 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||||||
MockEvent(
|
MockEvent(
|
||||||
sender="@alice:localhost",
|
sender="@alice:localhost",
|
||||||
room_id="!foo:bar",
|
room_id="!foo:bar",
|
||||||
content={
|
content={"body": "A message"},
|
||||||
"body": "A message",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
["content.body"]
|
["content.body"],
|
||||||
),
|
),
|
||||||
{
|
{"content": {"body": "A message"}},
|
||||||
"content": {
|
|
||||||
"body": "A message",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_event_fields_works_with_dot_keys(self):
|
def test_event_fields_works_with_dot_keys(self):
|
||||||
@ -180,17 +154,11 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||||||
MockEvent(
|
MockEvent(
|
||||||
sender="@alice:localhost",
|
sender="@alice:localhost",
|
||||||
room_id="!foo:bar",
|
room_id="!foo:bar",
|
||||||
content={
|
content={"key.with.dots": {}},
|
||||||
"key.with.dots": {},
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
["content.key\.with\.dots"]
|
["content.key\.with\.dots"],
|
||||||
),
|
),
|
||||||
{
|
{"content": {"key.with.dots": {}}},
|
||||||
"content": {
|
|
||||||
"key.with.dots": {},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_event_fields_works_with_nested_dot_keys(self):
|
def test_event_fields_works_with_nested_dot_keys(self):
|
||||||
@ -201,21 +169,12 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||||||
room_id="!foo:bar",
|
room_id="!foo:bar",
|
||||||
content={
|
content={
|
||||||
"not_me": 1,
|
"not_me": 1,
|
||||||
"nested.dot.key": {
|
"nested.dot.key": {"leaf.key": 42, "not_me_either": 1},
|
||||||
"leaf.key": 42,
|
|
||||||
"not_me_either": 1,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
["content.nested\.dot\.key.leaf\.key"]
|
["content.nested\.dot\.key.leaf\.key"],
|
||||||
),
|
),
|
||||||
{
|
{"content": {"nested.dot.key": {"leaf.key": 42}}},
|
||||||
"content": {
|
|
||||||
"nested.dot.key": {
|
|
||||||
"leaf.key": 42,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_event_fields_nops_with_unknown_keys(self):
|
def test_event_fields_nops_with_unknown_keys(self):
|
||||||
@ -224,17 +183,11 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||||||
MockEvent(
|
MockEvent(
|
||||||
sender="@alice:localhost",
|
sender="@alice:localhost",
|
||||||
room_id="!foo:bar",
|
room_id="!foo:bar",
|
||||||
content={
|
content={"foo": "bar"},
|
||||||
"foo": "bar",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
["content.foo", "content.notexists"]
|
["content.foo", "content.notexists"],
|
||||||
),
|
),
|
||||||
{
|
{"content": {"foo": "bar"}},
|
||||||
"content": {
|
|
||||||
"foo": "bar",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_event_fields_nops_with_non_dict_keys(self):
|
def test_event_fields_nops_with_non_dict_keys(self):
|
||||||
@ -243,13 +196,11 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||||||
MockEvent(
|
MockEvent(
|
||||||
sender="@alice:localhost",
|
sender="@alice:localhost",
|
||||||
room_id="!foo:bar",
|
room_id="!foo:bar",
|
||||||
content={
|
content={"foo": ["I", "am", "an", "array"]},
|
||||||
"foo": ["I", "am", "an", "array"],
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
["content.foo.am"]
|
["content.foo.am"],
|
||||||
),
|
),
|
||||||
{}
|
{},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_event_fields_nops_with_array_keys(self):
|
def test_event_fields_nops_with_array_keys(self):
|
||||||
@ -258,13 +209,11 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||||||
MockEvent(
|
MockEvent(
|
||||||
sender="@alice:localhost",
|
sender="@alice:localhost",
|
||||||
room_id="!foo:bar",
|
room_id="!foo:bar",
|
||||||
content={
|
content={"foo": ["I", "am", "an", "array"]},
|
||||||
"foo": ["I", "am", "an", "array"],
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
["content.foo.1"]
|
["content.foo.1"],
|
||||||
),
|
),
|
||||||
{}
|
{},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_event_fields_all_fields_if_empty(self):
|
def test_event_fields_all_fields_if_empty(self):
|
||||||
@ -274,31 +223,21 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||||||
type="foo",
|
type="foo",
|
||||||
event_id="test",
|
event_id="test",
|
||||||
room_id="!foo:bar",
|
room_id="!foo:bar",
|
||||||
content={
|
content={"foo": "bar"},
|
||||||
"foo": "bar",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
[]
|
[],
|
||||||
),
|
),
|
||||||
{
|
{
|
||||||
"type": "foo",
|
"type": "foo",
|
||||||
"event_id": "test",
|
"event_id": "test",
|
||||||
"room_id": "!foo:bar",
|
"room_id": "!foo:bar",
|
||||||
"content": {
|
"content": {"foo": "bar"},
|
||||||
"foo": "bar",
|
"unsigned": {},
|
||||||
},
|
},
|
||||||
"unsigned": {}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_event_fields_fail_if_fields_not_str(self):
|
def test_event_fields_fail_if_fields_not_str(self):
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
self.serialize(
|
self.serialize(
|
||||||
MockEvent(
|
MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4]
|
||||||
room_id="!foo:bar",
|
|
||||||
content={
|
|
||||||
"foo": "bar",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
["room_id", 4]
|
|
||||||
)
|
)
|
||||||
|
@ -23,10 +23,7 @@ from tests import unittest
|
|||||||
@unittest.DEBUG
|
@unittest.DEBUG
|
||||||
class ServerACLsTestCase(unittest.TestCase):
|
class ServerACLsTestCase(unittest.TestCase):
|
||||||
def test_blacklisted_server(self):
|
def test_blacklisted_server(self):
|
||||||
e = _create_acl_event({
|
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
|
||||||
"allow": ["*"],
|
|
||||||
"deny": ["evil.com"],
|
|
||||||
})
|
|
||||||
logging.info("ACL event: %s", e.content)
|
logging.info("ACL event: %s", e.content)
|
||||||
|
|
||||||
self.assertFalse(server_matches_acl_event("evil.com", e))
|
self.assertFalse(server_matches_acl_event("evil.com", e))
|
||||||
@ -36,10 +33,7 @@ class ServerACLsTestCase(unittest.TestCase):
|
|||||||
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
|
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
|
||||||
|
|
||||||
def test_block_ip_literals(self):
|
def test_block_ip_literals(self):
|
||||||
e = _create_acl_event({
|
e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
|
||||||
"allow_ip_literals": False,
|
|
||||||
"allow": ["*"],
|
|
||||||
})
|
|
||||||
logging.info("ACL event: %s", e.content)
|
logging.info("ACL event: %s", e.content)
|
||||||
|
|
||||||
self.assertFalse(server_matches_acl_event("1.2.3.4", e))
|
self.assertFalse(server_matches_acl_event("1.2.3.4", e))
|
||||||
@ -49,10 +43,12 @@ class ServerACLsTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
def _create_acl_event(content):
|
def _create_acl_event(content):
|
||||||
return FrozenEvent({
|
return FrozenEvent(
|
||||||
"room_id": "!a:b",
|
{
|
||||||
"event_id": "$a:b",
|
"room_id": "!a:b",
|
||||||
"type": "m.room.server_acls",
|
"event_id": "$a:b",
|
||||||
"sender": "@a:b",
|
"type": "m.room.server_acls",
|
||||||
"content": content
|
"sender": "@a:b",
|
||||||
})
|
"content": content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@ -45,20 +45,18 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
services = [
|
services = [
|
||||||
self._mkservice(is_interested=False),
|
self._mkservice(is_interested=False),
|
||||||
interested_service,
|
interested_service,
|
||||||
self._mkservice(is_interested=False)
|
self._mkservice(is_interested=False),
|
||||||
]
|
]
|
||||||
|
|
||||||
self.mock_store.get_app_services = Mock(return_value=services)
|
self.mock_store.get_app_services = Mock(return_value=services)
|
||||||
self.mock_store.get_user_by_id = Mock(return_value=[])
|
self.mock_store.get_user_by_id = Mock(return_value=[])
|
||||||
|
|
||||||
event = Mock(
|
event = Mock(
|
||||||
sender="@someone:anywhere",
|
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
|
||||||
type="m.room.message",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
)
|
||||||
self.mock_store.get_new_events_for_appservice.side_effect = [
|
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||||
(0, [event]),
|
(0, [event]),
|
||||||
(0, [])
|
(0, []),
|
||||||
]
|
]
|
||||||
self.mock_as_api.push = Mock()
|
self.mock_as_api.push = Mock()
|
||||||
yield self.handler.notify_interested_services(0)
|
yield self.handler.notify_interested_services(0)
|
||||||
@ -74,21 +72,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
self.mock_store.get_app_services = Mock(return_value=services)
|
self.mock_store.get_app_services = Mock(return_value=services)
|
||||||
self.mock_store.get_user_by_id = Mock(return_value=None)
|
self.mock_store.get_user_by_id = Mock(return_value=None)
|
||||||
|
|
||||||
event = Mock(
|
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
|
||||||
sender=user_id,
|
|
||||||
type="m.room.message",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.mock_as_api.push = Mock()
|
self.mock_as_api.push = Mock()
|
||||||
self.mock_as_api.query_user = Mock()
|
self.mock_as_api.query_user = Mock()
|
||||||
self.mock_store.get_new_events_for_appservice.side_effect = [
|
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||||
(0, [event]),
|
(0, [event]),
|
||||||
(0, [])
|
(0, []),
|
||||||
]
|
]
|
||||||
yield self.handler.notify_interested_services(0)
|
yield self.handler.notify_interested_services(0)
|
||||||
self.mock_as_api.query_user.assert_called_once_with(
|
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
|
||||||
services[0], user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_query_user_exists_known_user(self):
|
def test_query_user_exists_known_user(self):
|
||||||
@ -96,25 +88,19 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
services = [self._mkservice(is_interested=True)]
|
services = [self._mkservice(is_interested=True)]
|
||||||
services[0].is_interested_in_user = Mock(return_value=True)
|
services[0].is_interested_in_user = Mock(return_value=True)
|
||||||
self.mock_store.get_app_services = Mock(return_value=services)
|
self.mock_store.get_app_services = Mock(return_value=services)
|
||||||
self.mock_store.get_user_by_id = Mock(return_value={
|
self.mock_store.get_user_by_id = Mock(return_value={"name": user_id})
|
||||||
"name": user_id
|
|
||||||
})
|
|
||||||
|
|
||||||
event = Mock(
|
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
|
||||||
sender=user_id,
|
|
||||||
type="m.room.message",
|
|
||||||
room_id="!foo:bar"
|
|
||||||
)
|
|
||||||
self.mock_as_api.push = Mock()
|
self.mock_as_api.push = Mock()
|
||||||
self.mock_as_api.query_user = Mock()
|
self.mock_as_api.query_user = Mock()
|
||||||
self.mock_store.get_new_events_for_appservice.side_effect = [
|
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||||
(0, [event]),
|
(0, [event]),
|
||||||
(0, [])
|
(0, []),
|
||||||
]
|
]
|
||||||
yield self.handler.notify_interested_services(0)
|
yield self.handler.notify_interested_services(0)
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
self.mock_as_api.query_user.called,
|
self.mock_as_api.query_user.called,
|
||||||
"query_user called when it shouldn't have been."
|
"query_user called when it shouldn't have been.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -129,7 +115,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
services = [
|
services = [
|
||||||
self._mkservice_alias(is_interested_in_alias=False),
|
self._mkservice_alias(is_interested_in_alias=False),
|
||||||
interested_service,
|
interested_service,
|
||||||
self._mkservice_alias(is_interested_in_alias=False)
|
self._mkservice_alias(is_interested_in_alias=False),
|
||||||
]
|
]
|
||||||
|
|
||||||
self.mock_store.get_app_services = Mock(return_value=services)
|
self.mock_store.get_app_services = Mock(return_value=services)
|
||||||
@ -140,8 +126,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
result = yield self.handler.query_room_alias_exists(room_alias)
|
result = yield self.handler.query_room_alias_exists(room_alias)
|
||||||
|
|
||||||
self.mock_as_api.query_alias.assert_called_once_with(
|
self.mock_as_api.query_alias.assert_called_once_with(
|
||||||
interested_service,
|
interested_service, room_alias_str
|
||||||
room_alias_str
|
|
||||||
)
|
)
|
||||||
self.assertEquals(result.room_id, room_id)
|
self.assertEquals(result.room_id, room_id)
|
||||||
self.assertEquals(result.servers, servers)
|
self.assertEquals(result.servers, servers)
|
||||||
|
@ -35,7 +35,7 @@ class AuthHandlers(object):
|
|||||||
class AuthTestCase(unittest.TestCase):
|
class AuthTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.hs = yield setup_test_homeserver(handlers=None)
|
self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
|
||||||
self.hs.handlers = AuthHandlers(self.hs)
|
self.hs.handlers = AuthHandlers(self.hs)
|
||||||
self.auth_handler = self.hs.handlers.auth_handler
|
self.auth_handler = self.hs.handlers.auth_handler
|
||||||
self.macaroon_generator = self.hs.get_macaroon_generator()
|
self.macaroon_generator = self.hs.get_macaroon_generator()
|
||||||
@ -81,9 +81,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
def test_short_term_login_token_gives_user_id(self):
|
def test_short_term_login_token_gives_user_id(self):
|
||||||
self.hs.clock.now = 1000
|
self.hs.clock.now = 1000
|
||||||
|
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
|
||||||
"a_user", 5000
|
|
||||||
)
|
|
||||||
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
token
|
token
|
||||||
)
|
)
|
||||||
@ -98,17 +96,13 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
|
||||||
"a_user", 5000
|
|
||||||
)
|
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
||||||
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
macaroon.serialize()
|
macaroon.serialize()
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual("a_user", user_id)
|
||||||
"a_user", user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# add another "user_id" caveat, which might allow us to override the
|
# add another "user_id" caveat, which might allow us to override the
|
||||||
# user_id.
|
# user_id.
|
||||||
@ -165,7 +159,5 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_macaroon(self):
|
def _get_macaroon(self):
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
|
||||||
"user_a", 5000
|
|
||||||
)
|
|
||||||
return pymacaroons.Macaroon.deserialize(token)
|
return pymacaroons.Macaroon.deserialize(token)
|
||||||
|
@ -28,13 +28,13 @@ user2 = "@theresa:bbb"
|
|||||||
class DeviceTestCase(unittest.TestCase):
|
class DeviceTestCase(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(DeviceTestCase, self).__init__(*args, **kwargs)
|
super(DeviceTestCase, self).__init__(*args, **kwargs)
|
||||||
self.store = None # type: synapse.storage.DataStore
|
self.store = None # type: synapse.storage.DataStore
|
||||||
self.handler = None # type: synapse.handlers.device.DeviceHandler
|
self.handler = None # type: synapse.handlers.device.DeviceHandler
|
||||||
self.clock = None # type: utils.MockClock
|
self.clock = None # type: utils.MockClock
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
hs = yield utils.setup_test_homeserver()
|
hs = yield utils.setup_test_homeserver(self.addCleanup)
|
||||||
self.handler = hs.get_device_handler()
|
self.handler = hs.get_device_handler()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
@ -44,7 +44,7 @@ class DeviceTestCase(unittest.TestCase):
|
|||||||
res = yield self.handler.check_device_registered(
|
res = yield self.handler.check_device_registered(
|
||||||
user_id="@boris:foo",
|
user_id="@boris:foo",
|
||||||
device_id="fco",
|
device_id="fco",
|
||||||
initial_device_display_name="display name"
|
initial_device_display_name="display name",
|
||||||
)
|
)
|
||||||
self.assertEqual(res, "fco")
|
self.assertEqual(res, "fco")
|
||||||
|
|
||||||
@ -56,14 +56,14 @@ class DeviceTestCase(unittest.TestCase):
|
|||||||
res1 = yield self.handler.check_device_registered(
|
res1 = yield self.handler.check_device_registered(
|
||||||
user_id="@boris:foo",
|
user_id="@boris:foo",
|
||||||
device_id="fco",
|
device_id="fco",
|
||||||
initial_device_display_name="display name"
|
initial_device_display_name="display name",
|
||||||
)
|
)
|
||||||
self.assertEqual(res1, "fco")
|
self.assertEqual(res1, "fco")
|
||||||
|
|
||||||
res2 = yield self.handler.check_device_registered(
|
res2 = yield self.handler.check_device_registered(
|
||||||
user_id="@boris:foo",
|
user_id="@boris:foo",
|
||||||
device_id="fco",
|
device_id="fco",
|
||||||
initial_device_display_name="new display name"
|
initial_device_display_name="new display name",
|
||||||
)
|
)
|
||||||
self.assertEqual(res2, "fco")
|
self.assertEqual(res2, "fco")
|
||||||
|
|
||||||
@ -75,7 +75,7 @@ class DeviceTestCase(unittest.TestCase):
|
|||||||
device_id = yield self.handler.check_device_registered(
|
device_id = yield self.handler.check_device_registered(
|
||||||
user_id="@theresa:foo",
|
user_id="@theresa:foo",
|
||||||
device_id=None,
|
device_id=None,
|
||||||
initial_device_display_name="display"
|
initial_device_display_name="display",
|
||||||
)
|
)
|
||||||
|
|
||||||
dev = yield self.handler.store.get_device("@theresa:foo", device_id)
|
dev = yield self.handler.store.get_device("@theresa:foo", device_id)
|
||||||
@ -87,43 +87,53 @@ class DeviceTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
res = yield self.handler.get_devices_by_user(user1)
|
res = yield self.handler.get_devices_by_user(user1)
|
||||||
self.assertEqual(3, len(res))
|
self.assertEqual(3, len(res))
|
||||||
device_map = {
|
device_map = {d["device_id"]: d for d in res}
|
||||||
d["device_id"]: d for d in res
|
self.assertDictContainsSubset(
|
||||||
}
|
{
|
||||||
self.assertDictContainsSubset({
|
"user_id": user1,
|
||||||
"user_id": user1,
|
"device_id": "xyz",
|
||||||
"device_id": "xyz",
|
"display_name": "display 0",
|
||||||
"display_name": "display 0",
|
"last_seen_ip": None,
|
||||||
"last_seen_ip": None,
|
"last_seen_ts": None,
|
||||||
"last_seen_ts": None,
|
},
|
||||||
}, device_map["xyz"])
|
device_map["xyz"],
|
||||||
self.assertDictContainsSubset({
|
)
|
||||||
"user_id": user1,
|
self.assertDictContainsSubset(
|
||||||
"device_id": "fco",
|
{
|
||||||
"display_name": "display 1",
|
"user_id": user1,
|
||||||
"last_seen_ip": "ip1",
|
"device_id": "fco",
|
||||||
"last_seen_ts": 1000000,
|
"display_name": "display 1",
|
||||||
}, device_map["fco"])
|
"last_seen_ip": "ip1",
|
||||||
self.assertDictContainsSubset({
|
"last_seen_ts": 1000000,
|
||||||
"user_id": user1,
|
},
|
||||||
"device_id": "abc",
|
device_map["fco"],
|
||||||
"display_name": "display 2",
|
)
|
||||||
"last_seen_ip": "ip3",
|
self.assertDictContainsSubset(
|
||||||
"last_seen_ts": 3000000,
|
{
|
||||||
}, device_map["abc"])
|
"user_id": user1,
|
||||||
|
"device_id": "abc",
|
||||||
|
"display_name": "display 2",
|
||||||
|
"last_seen_ip": "ip3",
|
||||||
|
"last_seen_ts": 3000000,
|
||||||
|
},
|
||||||
|
device_map["abc"],
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_device(self):
|
def test_get_device(self):
|
||||||
yield self._record_users()
|
yield self._record_users()
|
||||||
|
|
||||||
res = yield self.handler.get_device(user1, "abc")
|
res = yield self.handler.get_device(user1, "abc")
|
||||||
self.assertDictContainsSubset({
|
self.assertDictContainsSubset(
|
||||||
"user_id": user1,
|
{
|
||||||
"device_id": "abc",
|
"user_id": user1,
|
||||||
"display_name": "display 2",
|
"device_id": "abc",
|
||||||
"last_seen_ip": "ip3",
|
"display_name": "display 2",
|
||||||
"last_seen_ts": 3000000,
|
"last_seen_ip": "ip3",
|
||||||
}, res)
|
"last_seen_ts": 3000000,
|
||||||
|
},
|
||||||
|
res,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_delete_device(self):
|
def test_delete_device(self):
|
||||||
@ -153,8 +163,7 @@ class DeviceTestCase(unittest.TestCase):
|
|||||||
def test_update_unknown_device(self):
|
def test_update_unknown_device(self):
|
||||||
update = {"display_name": "new_display"}
|
update = {"display_name": "new_display"}
|
||||||
with self.assertRaises(synapse.api.errors.NotFoundError):
|
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||||
yield self.handler.update_device("user_id", "unknown_device_id",
|
yield self.handler.update_device("user_id", "unknown_device_id", update)
|
||||||
update)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _record_users(self):
|
def _record_users(self):
|
||||||
@ -168,16 +177,17 @@ class DeviceTestCase(unittest.TestCase):
|
|||||||
yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
|
yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _record_user(self, user_id, device_id, display_name,
|
def _record_user(
|
||||||
access_token=None, ip=None):
|
self, user_id, device_id, display_name, access_token=None, ip=None
|
||||||
|
):
|
||||||
device_id = yield self.handler.check_device_registered(
|
device_id = yield self.handler.check_device_registered(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
initial_device_display_name=display_name
|
initial_device_display_name=display_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
if ip is not None:
|
if ip is not None:
|
||||||
yield self.store.insert_client_ip(
|
yield self.store.insert_client_ip(
|
||||||
user_id,
|
user_id, access_token, ip, "user_agent", device_id
|
||||||
access_token, ip, "user_agent", device_id)
|
)
|
||||||
self.clock.advance_time(1000)
|
self.clock.advance_time(1000)
|
||||||
|
@ -42,9 +42,11 @@ class DirectoryTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def register_query_handler(query_type, handler):
|
def register_query_handler(query_type, handler):
|
||||||
self.query_handlers[query_type] = handler
|
self.query_handlers[query_type] = handler
|
||||||
|
|
||||||
self.mock_registry.register_query_handler = register_query_handler
|
self.mock_registry.register_query_handler = register_query_handler
|
||||||
|
|
||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(
|
||||||
|
self.addCleanup,
|
||||||
http_client=None,
|
http_client=None,
|
||||||
resource_for_federation=Mock(),
|
resource_for_federation=Mock(),
|
||||||
federation_client=self.mock_federation,
|
federation_client=self.mock_federation,
|
||||||
@ -68,10 +70,7 @@ class DirectoryTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
result = yield self.handler.get_association(self.my_room)
|
result = yield self.handler.get_association(self.my_room)
|
||||||
|
|
||||||
self.assertEquals({
|
self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
|
||||||
"room_id": "!8765qwer:test",
|
|
||||||
"servers": ["test"],
|
|
||||||
}, result)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_remote_association(self):
|
def test_get_remote_association(self):
|
||||||
@ -81,16 +80,13 @@ class DirectoryTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
result = yield self.handler.get_association(self.remote_room)
|
result = yield self.handler.get_association(self.remote_room)
|
||||||
|
|
||||||
self.assertEquals({
|
self.assertEquals(
|
||||||
"room_id": "!8765qwer:test",
|
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result
|
||||||
"servers": ["test", "remote"],
|
)
|
||||||
}, result)
|
|
||||||
self.mock_federation.make_query.assert_called_with(
|
self.mock_federation.make_query.assert_called_with(
|
||||||
destination="remote",
|
destination="remote",
|
||||||
query_type="directory",
|
query_type="directory",
|
||||||
args={
|
args={"room_alias": "#another:remote"},
|
||||||
"room_alias": "#another:remote",
|
|
||||||
},
|
|
||||||
retry_on_dns_fail=False,
|
retry_on_dns_fail=False,
|
||||||
ignore_backoff=True,
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
@ -105,7 +101,4 @@ class DirectoryTestCase(unittest.TestCase):
|
|||||||
{"room_alias": "#your-room:test"}
|
{"room_alias": "#your-room:test"}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals({
|
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
|
||||||
"room_id": "!8765asdf:test",
|
|
||||||
"servers": ["test"],
|
|
||||||
}, response)
|
|
||||||
|
@ -28,14 +28,13 @@ from tests import unittest, utils
|
|||||||
class E2eKeysHandlerTestCase(unittest.TestCase):
|
class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs)
|
super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs)
|
||||||
self.hs = None # type: synapse.server.HomeServer
|
self.hs = None # type: synapse.server.HomeServer
|
||||||
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
|
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.hs = yield utils.setup_test_homeserver(
|
self.hs = yield utils.setup_test_homeserver(
|
||||||
handlers=None,
|
self.addCleanup, handlers=None, federation_client=mock.Mock()
|
||||||
federation_client=mock.Mock(),
|
|
||||||
)
|
)
|
||||||
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
|
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
|
||||||
|
|
||||||
@ -54,30 +53,21 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
keys = {
|
keys = {
|
||||||
"alg1:k1": "key1",
|
"alg1:k1": "key1",
|
||||||
"alg2:k2": {
|
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
||||||
"key": "key2",
|
"alg2:k3": {"key": "key3"},
|
||||||
"signatures": {"k1": "sig1"}
|
|
||||||
},
|
|
||||||
"alg2:k3": {
|
|
||||||
"key": "key3",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
res = yield self.handler.upload_keys_for_user(
|
res = yield self.handler.upload_keys_for_user(
|
||||||
local_user, device_id, {"one_time_keys": keys},
|
local_user, device_id, {"one_time_keys": keys}
|
||||||
)
|
)
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
|
||||||
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
|
||||||
})
|
|
||||||
|
|
||||||
# we should be able to change the signature without a problem
|
# we should be able to change the signature without a problem
|
||||||
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
|
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
|
||||||
res = yield self.handler.upload_keys_for_user(
|
res = yield self.handler.upload_keys_for_user(
|
||||||
local_user, device_id, {"one_time_keys": keys},
|
local_user, device_id, {"one_time_keys": keys}
|
||||||
)
|
)
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
|
||||||
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
|
||||||
})
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_change_one_time_keys(self):
|
def test_change_one_time_keys(self):
|
||||||
@ -87,25 +77,18 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
keys = {
|
keys = {
|
||||||
"alg1:k1": "key1",
|
"alg1:k1": "key1",
|
||||||
"alg2:k2": {
|
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
||||||
"key": "key2",
|
"alg2:k3": {"key": "key3"},
|
||||||
"signatures": {"k1": "sig1"}
|
|
||||||
},
|
|
||||||
"alg2:k3": {
|
|
||||||
"key": "key3",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
res = yield self.handler.upload_keys_for_user(
|
res = yield self.handler.upload_keys_for_user(
|
||||||
local_user, device_id, {"one_time_keys": keys},
|
local_user, device_id, {"one_time_keys": keys}
|
||||||
)
|
)
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
|
||||||
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
|
||||||
})
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.handler.upload_keys_for_user(
|
yield self.handler.upload_keys_for_user(
|
||||||
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}},
|
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
|
||||||
)
|
)
|
||||||
self.fail("No error when changing string key")
|
self.fail("No error when changing string key")
|
||||||
except errors.SynapseError:
|
except errors.SynapseError:
|
||||||
@ -113,7 +96,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.handler.upload_keys_for_user(
|
yield self.handler.upload_keys_for_user(
|
||||||
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}},
|
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
|
||||||
)
|
)
|
||||||
self.fail("No error when replacing dict key with string")
|
self.fail("No error when replacing dict key with string")
|
||||||
except errors.SynapseError:
|
except errors.SynapseError:
|
||||||
@ -121,9 +104,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.handler.upload_keys_for_user(
|
yield self.handler.upload_keys_for_user(
|
||||||
local_user, device_id, {
|
local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
|
||||||
"one_time_keys": {"alg1:k1": {"key": "key"}}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
self.fail("No error when replacing string key with dict")
|
self.fail("No error when replacing string key with dict")
|
||||||
except errors.SynapseError:
|
except errors.SynapseError:
|
||||||
@ -131,13 +112,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.handler.upload_keys_for_user(
|
yield self.handler.upload_keys_for_user(
|
||||||
local_user, device_id, {
|
local_user,
|
||||||
|
device_id,
|
||||||
|
{
|
||||||
"one_time_keys": {
|
"one_time_keys": {
|
||||||
"alg2:k2": {
|
"alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
|
||||||
"key": "key3",
|
}
|
||||||
"signatures": {"k1": "sig1"},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.fail("No error when replacing dict key")
|
self.fail("No error when replacing dict key")
|
||||||
@ -148,31 +128,20 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||||||
def test_claim_one_time_key(self):
|
def test_claim_one_time_key(self):
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
keys = {
|
keys = {"alg1:k1": "key1"}
|
||||||
"alg1:k1": "key1",
|
|
||||||
}
|
|
||||||
|
|
||||||
res = yield self.handler.upload_keys_for_user(
|
res = yield self.handler.upload_keys_for_user(
|
||||||
local_user, device_id, {"one_time_keys": keys},
|
local_user, device_id, {"one_time_keys": keys}
|
||||||
)
|
)
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
|
||||||
"one_time_key_counts": {"alg1": 1}
|
|
||||||
})
|
|
||||||
|
|
||||||
res2 = yield self.handler.claim_one_time_keys({
|
res2 = yield self.handler.claim_one_time_keys(
|
||||||
"one_time_keys": {
|
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||||
local_user: {
|
)
|
||||||
device_id: "alg1"
|
self.assertEqual(
|
||||||
}
|
res2,
|
||||||
}
|
{
|
||||||
}, timeout=None)
|
"failures": {},
|
||||||
self.assertEqual(res2, {
|
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
|
||||||
"failures": {},
|
},
|
||||||
"one_time_keys": {
|
)
|
||||||
local_user: {
|
|
||||||
device_id: {
|
|
||||||
"alg1:k1": "key1"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
@ -39,8 +39,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
prev_state = UserPresenceState.default(user_id)
|
prev_state = UserPresenceState.default(user_id)
|
||||||
new_state = prev_state.copy_and_replace(
|
new_state = prev_state.copy_and_replace(
|
||||||
state=PresenceState.ONLINE,
|
state=PresenceState.ONLINE, last_active_ts=now
|
||||||
last_active_ts=now,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
state, persist_and_notify, federation_ping = handle_update(
|
state, persist_and_notify, federation_ping = handle_update(
|
||||||
@ -54,23 +53,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(state.last_federation_update_ts, now)
|
self.assertEquals(state.last_federation_update_ts, now)
|
||||||
|
|
||||||
self.assertEquals(wheel_timer.insert.call_count, 3)
|
self.assertEquals(wheel_timer.insert.call_count, 3)
|
||||||
wheel_timer.insert.assert_has_calls([
|
wheel_timer.insert.assert_has_calls(
|
||||||
call(
|
[
|
||||||
now=now,
|
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
|
||||||
obj=user_id,
|
call(
|
||||||
then=new_state.last_active_ts + IDLE_TIMER
|
now=now,
|
||||||
),
|
obj=user_id,
|
||||||
call(
|
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
|
||||||
now=now,
|
),
|
||||||
obj=user_id,
|
call(
|
||||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
now=now,
|
||||||
),
|
obj=user_id,
|
||||||
call(
|
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
|
||||||
now=now,
|
),
|
||||||
obj=user_id,
|
],
|
||||||
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
|
any_order=True,
|
||||||
),
|
)
|
||||||
], any_order=True)
|
|
||||||
|
|
||||||
def test_online_to_online(self):
|
def test_online_to_online(self):
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
@ -79,14 +77,11 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
prev_state = UserPresenceState.default(user_id)
|
prev_state = UserPresenceState.default(user_id)
|
||||||
prev_state = prev_state.copy_and_replace(
|
prev_state = prev_state.copy_and_replace(
|
||||||
state=PresenceState.ONLINE,
|
state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
|
||||||
last_active_ts=now,
|
|
||||||
currently_active=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
new_state = prev_state.copy_and_replace(
|
new_state = prev_state.copy_and_replace(
|
||||||
state=PresenceState.ONLINE,
|
state=PresenceState.ONLINE, last_active_ts=now
|
||||||
last_active_ts=now,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
state, persist_and_notify, federation_ping = handle_update(
|
state, persist_and_notify, federation_ping = handle_update(
|
||||||
@ -101,23 +96,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(state.last_federation_update_ts, now)
|
self.assertEquals(state.last_federation_update_ts, now)
|
||||||
|
|
||||||
self.assertEquals(wheel_timer.insert.call_count, 3)
|
self.assertEquals(wheel_timer.insert.call_count, 3)
|
||||||
wheel_timer.insert.assert_has_calls([
|
wheel_timer.insert.assert_has_calls(
|
||||||
call(
|
[
|
||||||
now=now,
|
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
|
||||||
obj=user_id,
|
call(
|
||||||
then=new_state.last_active_ts + IDLE_TIMER
|
now=now,
|
||||||
),
|
obj=user_id,
|
||||||
call(
|
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
|
||||||
now=now,
|
),
|
||||||
obj=user_id,
|
call(
|
||||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
now=now,
|
||||||
),
|
obj=user_id,
|
||||||
call(
|
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
|
||||||
now=now,
|
),
|
||||||
obj=user_id,
|
],
|
||||||
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
|
any_order=True,
|
||||||
),
|
)
|
||||||
], any_order=True)
|
|
||||||
|
|
||||||
def test_online_to_online_last_active_noop(self):
|
def test_online_to_online_last_active_noop(self):
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
@ -132,8 +126,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
new_state = prev_state.copy_and_replace(
|
new_state = prev_state.copy_and_replace(
|
||||||
state=PresenceState.ONLINE,
|
state=PresenceState.ONLINE, last_active_ts=now
|
||||||
last_active_ts=now,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
state, persist_and_notify, federation_ping = handle_update(
|
state, persist_and_notify, federation_ping = handle_update(
|
||||||
@ -148,23 +141,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(state.last_federation_update_ts, now)
|
self.assertEquals(state.last_federation_update_ts, now)
|
||||||
|
|
||||||
self.assertEquals(wheel_timer.insert.call_count, 3)
|
self.assertEquals(wheel_timer.insert.call_count, 3)
|
||||||
wheel_timer.insert.assert_has_calls([
|
wheel_timer.insert.assert_has_calls(
|
||||||
call(
|
[
|
||||||
now=now,
|
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
|
||||||
obj=user_id,
|
call(
|
||||||
then=new_state.last_active_ts + IDLE_TIMER
|
now=now,
|
||||||
),
|
obj=user_id,
|
||||||
call(
|
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
|
||||||
now=now,
|
),
|
||||||
obj=user_id,
|
call(
|
||||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
now=now,
|
||||||
),
|
obj=user_id,
|
||||||
call(
|
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
|
||||||
now=now,
|
),
|
||||||
obj=user_id,
|
],
|
||||||
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
|
any_order=True,
|
||||||
),
|
)
|
||||||
], any_order=True)
|
|
||||||
|
|
||||||
def test_online_to_online_last_active(self):
|
def test_online_to_online_last_active(self):
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
@ -178,9 +170,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
currently_active=True,
|
currently_active=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
new_state = prev_state.copy_and_replace(
|
new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE)
|
||||||
state=PresenceState.ONLINE,
|
|
||||||
)
|
|
||||||
|
|
||||||
state, persist_and_notify, federation_ping = handle_update(
|
state, persist_and_notify, federation_ping = handle_update(
|
||||||
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
|
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
|
||||||
@ -193,18 +183,17 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(state.last_federation_update_ts, now)
|
self.assertEquals(state.last_federation_update_ts, now)
|
||||||
|
|
||||||
self.assertEquals(wheel_timer.insert.call_count, 2)
|
self.assertEquals(wheel_timer.insert.call_count, 2)
|
||||||
wheel_timer.insert.assert_has_calls([
|
wheel_timer.insert.assert_has_calls(
|
||||||
call(
|
[
|
||||||
now=now,
|
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
|
||||||
obj=user_id,
|
call(
|
||||||
then=new_state.last_active_ts + IDLE_TIMER
|
now=now,
|
||||||
),
|
obj=user_id,
|
||||||
call(
|
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
|
||||||
now=now,
|
),
|
||||||
obj=user_id,
|
],
|
||||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
any_order=True,
|
||||||
)
|
)
|
||||||
], any_order=True)
|
|
||||||
|
|
||||||
def test_remote_ping_timer(self):
|
def test_remote_ping_timer(self):
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
@ -213,13 +202,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
prev_state = UserPresenceState.default(user_id)
|
prev_state = UserPresenceState.default(user_id)
|
||||||
prev_state = prev_state.copy_and_replace(
|
prev_state = prev_state.copy_and_replace(
|
||||||
state=PresenceState.ONLINE,
|
state=PresenceState.ONLINE, last_active_ts=now
|
||||||
last_active_ts=now,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
new_state = prev_state.copy_and_replace(
|
new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE)
|
||||||
state=PresenceState.ONLINE,
|
|
||||||
)
|
|
||||||
|
|
||||||
state, persist_and_notify, federation_ping = handle_update(
|
state, persist_and_notify, federation_ping = handle_update(
|
||||||
prev_state, new_state, is_mine=False, wheel_timer=wheel_timer, now=now
|
prev_state, new_state, is_mine=False, wheel_timer=wheel_timer, now=now
|
||||||
@ -232,13 +218,16 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(new_state.status_msg, state.status_msg)
|
self.assertEquals(new_state.status_msg, state.status_msg)
|
||||||
|
|
||||||
self.assertEquals(wheel_timer.insert.call_count, 1)
|
self.assertEquals(wheel_timer.insert.call_count, 1)
|
||||||
wheel_timer.insert.assert_has_calls([
|
wheel_timer.insert.assert_has_calls(
|
||||||
call(
|
[
|
||||||
now=now,
|
call(
|
||||||
obj=user_id,
|
now=now,
|
||||||
then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT
|
obj=user_id,
|
||||||
),
|
then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT,
|
||||||
], any_order=True)
|
)
|
||||||
|
],
|
||||||
|
any_order=True,
|
||||||
|
)
|
||||||
|
|
||||||
def test_online_to_offline(self):
|
def test_online_to_offline(self):
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
@ -247,14 +236,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
prev_state = UserPresenceState.default(user_id)
|
prev_state = UserPresenceState.default(user_id)
|
||||||
prev_state = prev_state.copy_and_replace(
|
prev_state = prev_state.copy_and_replace(
|
||||||
state=PresenceState.ONLINE,
|
state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
|
||||||
last_active_ts=now,
|
|
||||||
currently_active=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
new_state = prev_state.copy_and_replace(
|
new_state = prev_state.copy_and_replace(state=PresenceState.OFFLINE)
|
||||||
state=PresenceState.OFFLINE,
|
|
||||||
)
|
|
||||||
|
|
||||||
state, persist_and_notify, federation_ping = handle_update(
|
state, persist_and_notify, federation_ping = handle_update(
|
||||||
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
|
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
|
||||||
@ -273,14 +258,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
prev_state = UserPresenceState.default(user_id)
|
prev_state = UserPresenceState.default(user_id)
|
||||||
prev_state = prev_state.copy_and_replace(
|
prev_state = prev_state.copy_and_replace(
|
||||||
state=PresenceState.ONLINE,
|
state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
|
||||||
last_active_ts=now,
|
|
||||||
currently_active=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
new_state = prev_state.copy_and_replace(
|
new_state = prev_state.copy_and_replace(state=PresenceState.UNAVAILABLE)
|
||||||
state=PresenceState.UNAVAILABLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
state, persist_and_notify, federation_ping = handle_update(
|
state, persist_and_notify, federation_ping = handle_update(
|
||||||
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
|
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
|
||||||
@ -293,13 +274,16 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(new_state.status_msg, state.status_msg)
|
self.assertEquals(new_state.status_msg, state.status_msg)
|
||||||
|
|
||||||
self.assertEquals(wheel_timer.insert.call_count, 1)
|
self.assertEquals(wheel_timer.insert.call_count, 1)
|
||||||
wheel_timer.insert.assert_has_calls([
|
wheel_timer.insert.assert_has_calls(
|
||||||
call(
|
[
|
||||||
now=now,
|
call(
|
||||||
obj=user_id,
|
now=now,
|
||||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
obj=user_id,
|
||||||
)
|
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
|
||||||
], any_order=True)
|
)
|
||||||
|
],
|
||||||
|
any_order=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PresenceTimeoutTestCase(unittest.TestCase):
|
class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
@ -314,9 +298,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||||||
last_user_sync_ts=now,
|
last_user_sync_ts=now,
|
||||||
)
|
)
|
||||||
|
|
||||||
new_state = handle_timeout(
|
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||||
state, is_mine=True, syncing_user_ids=set(), now=now
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsNotNone(new_state)
|
self.assertIsNotNone(new_state)
|
||||||
self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
|
self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
|
||||||
@ -332,9 +314,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||||||
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
|
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
new_state = handle_timeout(
|
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||||
state, is_mine=True, syncing_user_ids=set(), now=now
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsNotNone(new_state)
|
self.assertIsNotNone(new_state)
|
||||||
self.assertEquals(new_state.state, PresenceState.OFFLINE)
|
self.assertEquals(new_state.state, PresenceState.OFFLINE)
|
||||||
@ -369,9 +349,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||||||
last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1,
|
last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
new_state = handle_timeout(
|
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||||
state, is_mine=True, syncing_user_ids=set(), now=now
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsNotNone(new_state)
|
self.assertIsNotNone(new_state)
|
||||||
self.assertEquals(new_state, new_state)
|
self.assertEquals(new_state, new_state)
|
||||||
@ -388,9 +366,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||||||
last_federation_update_ts=now,
|
last_federation_update_ts=now,
|
||||||
)
|
)
|
||||||
|
|
||||||
new_state = handle_timeout(
|
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||||
state, is_mine=True, syncing_user_ids=set(), now=now
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsNone(new_state)
|
self.assertIsNone(new_state)
|
||||||
|
|
||||||
@ -425,9 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||||||
last_federation_update_ts=now,
|
last_federation_update_ts=now,
|
||||||
)
|
)
|
||||||
|
|
||||||
new_state = handle_timeout(
|
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||||
state, is_mine=True, syncing_user_ids=set(), now=now
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsNotNone(new_state)
|
self.assertIsNotNone(new_state)
|
||||||
self.assertEquals(state, new_state)
|
self.assertEquals(state, new_state)
|
||||||
|
@ -48,15 +48,14 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
self.mock_registry.register_query_handler = register_query_handler
|
self.mock_registry.register_query_handler = register_query_handler
|
||||||
|
|
||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(
|
||||||
|
self.addCleanup,
|
||||||
http_client=None,
|
http_client=None,
|
||||||
handlers=None,
|
handlers=None,
|
||||||
resource_for_federation=Mock(),
|
resource_for_federation=Mock(),
|
||||||
federation_client=self.mock_federation,
|
federation_client=self.mock_federation,
|
||||||
federation_server=Mock(),
|
federation_server=Mock(),
|
||||||
federation_registry=self.mock_registry,
|
federation_registry=self.mock_registry,
|
||||||
ratelimiter=NonCallableMock(spec_set=[
|
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
||||||
"send_message",
|
|
||||||
])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
@ -74,9 +73,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_my_name(self):
|
def test_get_my_name(self):
|
||||||
yield self.store.set_profile_displayname(
|
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||||
self.frank.localpart, "Frank"
|
|
||||||
)
|
|
||||||
|
|
||||||
displayname = yield self.handler.get_displayname(self.frank)
|
displayname = yield self.handler.get_displayname(self.frank)
|
||||||
|
|
||||||
@ -85,22 +82,18 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_set_my_name(self):
|
def test_set_my_name(self):
|
||||||
yield self.handler.set_displayname(
|
yield self.handler.set_displayname(
|
||||||
self.frank,
|
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||||
synapse.types.create_requester(self.frank),
|
|
||||||
"Frank Jr."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
(yield self.store.get_profile_displayname(self.frank.localpart)),
|
(yield self.store.get_profile_displayname(self.frank.localpart)),
|
||||||
"Frank Jr."
|
"Frank Jr.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_set_my_name_noauth(self):
|
def test_set_my_name_noauth(self):
|
||||||
d = self.handler.set_displayname(
|
d = self.handler.set_displayname(
|
||||||
self.frank,
|
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
||||||
synapse.types.create_requester(self.bob),
|
|
||||||
"Frank Jr."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.assertFailure(d, AuthError)
|
yield self.assertFailure(d, AuthError)
|
||||||
@ -145,11 +138,12 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_set_my_avatar(self):
|
def test_set_my_avatar(self):
|
||||||
yield self.handler.set_avatar_url(
|
yield self.handler.set_avatar_url(
|
||||||
self.frank, synapse.types.create_requester(self.frank),
|
self.frank,
|
||||||
"http://my.server/pic.gif"
|
synapse.types.create_requester(self.frank),
|
||||||
|
"http://my.server/pic.gif",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
|
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
|
||||||
"http://my.server/pic.gif"
|
"http://my.server/pic.gif",
|
||||||
)
|
)
|
||||||
|
@ -17,7 +17,7 @@ from mock import Mock
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import RegistrationError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.handlers.register import RegistrationHandler
|
from synapse.handlers.register import RegistrationHandler
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
@ -40,13 +40,15 @@ class RegistrationTestCase(unittest.TestCase):
|
|||||||
self.mock_distributor.declare("registered_user")
|
self.mock_distributor.declare("registered_user")
|
||||||
self.mock_captcha_client = Mock()
|
self.mock_captcha_client = Mock()
|
||||||
self.hs = yield setup_test_homeserver(
|
self.hs = yield setup_test_homeserver(
|
||||||
|
self.addCleanup,
|
||||||
handlers=None,
|
handlers=None,
|
||||||
http_client=None,
|
http_client=None,
|
||||||
expire_access_token=True,
|
expire_access_token=True,
|
||||||
profile_handler=Mock(),
|
profile_handler=Mock(),
|
||||||
)
|
)
|
||||||
self.macaroon_generator = Mock(
|
self.macaroon_generator = Mock(
|
||||||
generate_access_token=Mock(return_value='secret'))
|
generate_access_token=Mock(return_value='secret')
|
||||||
|
)
|
||||||
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
||||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||||
self.handler = self.hs.get_handlers().registration_handler
|
self.handler = self.hs.get_handlers().registration_handler
|
||||||
@ -62,7 +64,8 @@ class RegistrationTestCase(unittest.TestCase):
|
|||||||
user_id = "@someone:test"
|
user_id = "@someone:test"
|
||||||
requester = create_requester("@as:test")
|
requester = create_requester("@as:test")
|
||||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
requester, local_part, display_name)
|
requester, local_part, display_name
|
||||||
|
)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
||||||
@ -73,13 +76,15 @@ class RegistrationTestCase(unittest.TestCase):
|
|||||||
yield store.register(
|
yield store.register(
|
||||||
user_id=frank.to_string(),
|
user_id=frank.to_string(),
|
||||||
token="jkv;g498752-43gj['eamb!-5",
|
token="jkv;g498752-43gj['eamb!-5",
|
||||||
password_hash=None)
|
password_hash=None,
|
||||||
|
)
|
||||||
local_part = "frank"
|
local_part = "frank"
|
||||||
display_name = "Frank"
|
display_name = "Frank"
|
||||||
user_id = "@frank:test"
|
user_id = "@frank:test"
|
||||||
requester = create_requester("@as:test")
|
requester = create_requester("@as:test")
|
||||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
requester, local_part, display_name)
|
requester, local_part, display_name
|
||||||
|
)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
||||||
@ -104,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase):
|
|||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.lots_of_users)
|
return_value=defer.succeed(self.lots_of_users)
|
||||||
)
|
)
|
||||||
with self.assertRaises(RegistrationError):
|
with self.assertRaises(AuthError):
|
||||||
yield self.handler.get_or_create_user("requester", 'b', "display_name")
|
yield self.handler.get_or_create_user("requester", 'b', "display_name")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -113,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase):
|
|||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.lots_of_users)
|
return_value=defer.succeed(self.lots_of_users)
|
||||||
)
|
)
|
||||||
with self.assertRaises(RegistrationError):
|
with self.assertRaises(AuthError):
|
||||||
yield self.handler.register(localpart="local_part")
|
yield self.handler.register(localpart="local_part")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -122,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase):
|
|||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.lots_of_users)
|
return_value=defer.succeed(self.lots_of_users)
|
||||||
)
|
)
|
||||||
with self.assertRaises(RegistrationError):
|
with self.assertRaises(AuthError):
|
||||||
yield self.handler.register_saml2(localpart="local_part")
|
yield self.handler.register_saml2(localpart="local_part")
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user