Merge branch 'develop' into send_sni_for_federation_requests
commit 7041cd872b
156 changed files with 7063 additions and 5235 deletions
|
|
@@ -11,23 +11,44 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
from synapse.replication.tcp.client import (
|
||||
ReplicationClientFactory,
|
||||
ReplicationClientHandler,
|
||||
)
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
|
||||
class TestReplicationClientHandler(ReplicationClientHandler):
|
||||
"""Overrides on_rdata so that we can wait for it to happen"""
|
||||
def __init__(self, store):
|
||||
super(TestReplicationClientHandler, self).__init__(store)
|
||||
self._rdata_awaiters = []
|
||||
|
||||
def await_replication(self):
|
||||
d = Deferred()
|
||||
self._rdata_awaiters.append(d)
|
||||
return make_deferred_yieldable(d)
|
||||
|
||||
def on_rdata(self, stream_name, token, rows):
|
||||
awaiters = self._rdata_awaiters
|
||||
self._rdata_awaiters = []
|
||||
super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
|
||||
with PreserveLoggingContext():
|
||||
for a in awaiters:
|
||||
a.callback(None)
|
||||
|
||||
|
||||
class BaseSlavedStoreTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
|
|
@@ -52,7 +73,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
|
|||
self.addCleanup(listener.stopListening)
|
||||
self.streamer = server_factory.streamer
|
||||
|
||||
self.replication_handler = ReplicationClientHandler(self.slaved_store)
|
||||
self.replication_handler = TestReplicationClientHandler(self.slaved_store)
|
||||
client_factory = ReplicationClientFactory(
|
||||
self.hs, "client_name", self.replication_handler
|
||||
)
|
||||
|
|
@@ -60,12 +81,14 @@
|
|||
self.addCleanup(client_factory.stopTrying)
|
||||
self.addCleanup(client_connector.disconnect)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(self):
|
||||
yield self.streamer.on_notifier_poke()
|
||||
d = self.replication_handler.await_sync("replication_test")
|
||||
self.streamer.send_sync_to_all_connections("replication_test")
|
||||
yield d
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
wait for the replication to occur.
|
||||
"""
|
||||
# xxx: should we be more specific in what we wait for?
|
||||
d = self.replication_handler.await_replication()
|
||||
self.streamer.on_notifier_poke()
|
||||
return d
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check(self, method, args, expected_result=None):
|
||||
|
|
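The replicate()/check() pair above is the intended workflow for the slaved-store tests: poke the master-side streamer, wait for TestReplicationClientHandler.on_rdata to fire, then compare what the master and slaved stores return for the same read. A minimal sketch of a test built on these helpers, assuming a master_store attribute and illustrative store method names that are not part of this diff:

    @defer.inlineCallbacks
    def test_write_is_replicated(self):
        # Write through the master store, as any storage test would.
        yield self.master_store.do_some_write()      # hypothetical helper

        # Poke the streamer and block until on_rdata has fired on the client.
        yield self.replicate()

        # check() runs the named method on both stores and compares the results.
        yield self.check("do_some_read", [])          # hypothetical store method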
|
|||
|
|
@@ -222,9 +222,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
|||
state_ids = {
|
||||
key: e.event_id for key, e in state.items()
|
||||
}
|
||||
context = EventContext()
|
||||
context.current_state_ids = state_ids
|
||||
context.prev_state_ids = state_ids
|
||||
context = EventContext.with_state(
|
||||
state_group=None,
|
||||
current_state_ids=state_ids,
|
||||
prev_state_ids=state_ids
|
||||
)
|
||||
else:
|
||||
state_handler = self.hs.get_state_handler()
|
||||
context = yield state_handler.compute_event_context(event)
|
||||
|
|
|
|||
tests/rest/client/v1/test_admin.py (new file, 305 lines)
|
|
@@ -0,0 +1,305 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v1.admin import register_servlets
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import (
|
||||
ThreadedMemoryReactorClock,
|
||||
make_request,
|
||||
render,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
|
||||
|
||||
class UserRegisterTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
self.clock = ThreadedMemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
self.url = "/_matrix/client/r0/admin/register"
|
||||
|
||||
self.registration_handler = Mock()
|
||||
self.identity_handler = Mock()
|
||||
self.login_handler = Mock()
|
||||
self.device_handler = Mock()
|
||||
self.device_handler.check_device_registered = Mock(return_value="FAKE")
|
||||
|
||||
self.datastore = Mock(return_value=Mock())
|
||||
self.datastore.get_current_state_deltas = Mock(return_value=[])
|
||||
|
||||
self.secrets = Mock()
|
||||
|
||||
self.hs = setup_test_homeserver(
|
||||
http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
|
||||
self.hs.config.registration_shared_secret = u"shared"
|
||||
|
||||
self.hs.get_media_repository = Mock()
|
||||
self.hs.get_deactivate_account_handler = Mock()
|
||||
|
||||
self.resource = JsonResource(self.hs)
|
||||
register_servlets(self.hs, self.resource)
|
||||
|
||||
def test_disabled(self):
|
||||
"""
|
||||
If there is no shared secret, registration through this method will be
|
||||
prevented.
|
||||
"""
|
||||
self.hs.config.registration_shared_secret = None
|
||||
|
||||
request, channel = make_request("POST", self.url, b'{}')
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(
|
||||
'Shared secret registration is not enabled', channel.json_body["error"]
|
||||
)
|
||||
|
||||
def test_get_nonce(self):
|
||||
"""
|
||||
Calling GET on the endpoint will return a randomised nonce, using the
|
||||
homeserver's secrets provider.
|
||||
"""
|
||||
secrets = Mock()
|
||||
secrets.token_hex = Mock(return_value="abcd")
|
||||
|
||||
self.hs.get_secrets = Mock(return_value=secrets)
|
||||
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(channel.json_body, {"nonce": "abcd"})
|
||||
|
||||
def test_expired_nonce(self):
|
||||
"""
|
||||
Calling GET on the endpoint will return a randomised nonce, which will
|
||||
only last for SALT_TIMEOUT (60s).
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
# 59 seconds
|
||||
self.clock.advance(59)
|
||||
|
||||
body = json.dumps({"nonce": nonce})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('username must be specified', channel.json_body["error"])
|
||||
|
||||
# 61 seconds
|
||||
self.clock.advance(2)
|
||||
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('unrecognised nonce', channel.json_body["error"])
|
||||
|
||||
def test_register_incorrect_nonce(self):
|
||||
"""
|
||||
Only the provided nonce can be used, as it's checked in the MAC.
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||
want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
body = json.dumps(
|
||||
{
|
||||
"nonce": nonce,
|
||||
"username": "bob",
|
||||
"password": "abc123",
|
||||
"admin": True,
|
||||
"mac": want_mac,
|
||||
}
|
||||
).encode('utf8')
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("HMAC incorrect", channel.json_body["error"])
|
||||
|
||||
def test_register_correct_nonce(self):
|
||||
"""
|
||||
When the correct nonce is provided, and the right key is provided, the
|
||||
user is registered.
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
body = json.dumps(
|
||||
{
|
||||
"nonce": nonce,
|
||||
"username": "bob",
|
||||
"password": "abc123",
|
||||
"admin": True,
|
||||
"mac": want_mac,
|
||||
}
|
||||
).encode('utf8')
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["user_id"])
|
||||
|
||||
def test_nonce_reuse(self):
|
||||
"""
|
||||
A valid unrecognised nonce.
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
body = json.dumps(
|
||||
{
|
||||
"nonce": nonce,
|
||||
"username": "bob",
|
||||
"password": "abc123",
|
||||
"admin": True,
|
||||
"mac": want_mac,
|
||||
}
|
||||
).encode('utf8')
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["user_id"])
|
||||
|
||||
# Now, try and reuse it
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('unrecognised nonce', channel.json_body["error"])
|
||||
|
||||
def test_missing_parts(self):
|
||||
"""
|
||||
Synapse will complain if you don't give nonce, username, password, and
|
||||
mac. Admin is optional. Additional checks are done for length and
|
||||
type.
|
||||
"""
|
||||
def nonce():
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
return channel.json_body["nonce"]
|
||||
|
||||
#
|
||||
# Nonce check
|
||||
#
|
||||
|
||||
# Must be present
|
||||
body = json.dumps({})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('nonce must be specified', channel.json_body["error"])
|
||||
|
||||
#
|
||||
# Username checks
|
||||
#
|
||||
|
||||
# Must be present
|
||||
body = json.dumps({"nonce": nonce()})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('username must be specified', channel.json_body["error"])
|
||||
|
||||
# Must be a string
|
||||
body = json.dumps({"nonce": nonce(), "username": 1234})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||
|
||||
# Must not have null bytes
|
||||
body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||
|
||||
# Must not have null bytes
|
||||
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||
|
||||
#
|
||||
# Username checks
|
||||
#
|
||||
|
||||
# Must be present
|
||||
body = json.dumps({"nonce": nonce(), "username": "a"})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('password must be specified', channel.json_body["error"])
|
||||
|
||||
# Must be a string
|
||||
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
||||
|
||||
# Must not have null bytes
|
||||
body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
||||
|
||||
# Super long
|
||||
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
||||
|
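Taken together, the nonce tests above pin down the client side of shared-secret registration: GET the endpoint for a one-time nonce, then POST a body whose "mac" field is an HMAC-SHA1, keyed with registration_shared_secret, over the nonce, username, password and admin flag joined by NUL bytes. A standalone sketch of building that body; the b"notadmin" value for the non-admin case is an assumption, since only the admin path is exercised above:

    import hashlib
    import hmac
    import json

    def admin_register_body(nonce, username, password, shared_secret, admin=True):
        # HMAC-SHA1 over nonce \x00 username \x00 password \x00 admin-flag,
        # keyed with the homeserver's registration_shared_secret.
        mac = hmac.new(key=shared_secret.encode("utf8"), digestmod=hashlib.sha1)
        mac.update(nonce.encode("utf8"))
        mac.update(b"\x00" + username.encode("utf8"))
        mac.update(b"\x00" + password.encode("utf8"))
        mac.update(b"\x00" + (b"admin" if admin else b"notadmin"))
        return json.dumps({
            "nonce": nonce,
            "username": username,
            "password": password,
            "admin": admin,
            "mac": mac.hexdigest(),
        })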
|
@@ -14,100 +14,30 @@
|
|||
# limitations under the License.
|
||||
|
||||
""" Tests REST events for /events paths."""
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
from six import PY3
|
||||
|
||||
# twisted imports
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.rest.client.v1.events
|
||||
import synapse.rest.client.v1.register
|
||||
import synapse.rest.client.v1.room
|
||||
|
||||
from tests import unittest
|
||||
|
||||
from ....utils import MockHttpResource, setup_test_homeserver
|
||||
from .utils import RestTestCase
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/api/v1"
|
||||
|
||||
|
||||
class EventStreamPaginationApiTestCase(unittest.TestCase):
|
||||
""" Tests event streaming query parameters and start/end keys used in the
|
||||
Pagination stream API. """
|
||||
user_id = "sid1"
|
||||
|
||||
def setUp(self):
|
||||
# configure stream and inject items
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def TODO_test_long_poll(self):
|
||||
# stream from 'end' key, send (self+other) message, expect message.
|
||||
|
||||
# stream from 'END', send (self+other) message, expect message.
|
||||
|
||||
# stream from 'end' key, send (self+other) topic, expect topic.
|
||||
|
||||
# stream from 'END', send (self+other) topic, expect topic.
|
||||
|
||||
# stream from 'end' key, send (self+other) invite, expect invite.
|
||||
|
||||
# stream from 'END', send (self+other) invite, expect invite.
|
||||
|
||||
pass
|
||||
|
||||
def TODO_test_stream_forward(self):
|
||||
# stream from START, expect injected items
|
||||
|
||||
# stream from 'start' key, expect same content
|
||||
|
||||
# stream from 'end' key, expect nothing
|
||||
|
||||
# stream from 'END', expect nothing
|
||||
|
||||
# The following is needed for cases where content is removed e.g. you
|
||||
# left a room, so the token you're streaming from is > the one that
|
||||
# would be returned naturally from START>END.
|
||||
# stream from very new token (higher than end key), expect same token
|
||||
# returned as end key
|
||||
pass
|
||||
|
||||
def TODO_test_limits(self):
|
||||
# stream from a key, expect limit_num items
|
||||
|
||||
# stream from START, expect limit_num items
|
||||
|
||||
pass
|
||||
|
||||
def TODO_test_range(self):
|
||||
# stream from key to key, expect X items
|
||||
|
||||
# stream from key to END, expect X items
|
||||
|
||||
# stream from START to key, expect X items
|
||||
|
||||
# stream from START to END, expect all items
|
||||
pass
|
||||
|
||||
def TODO_test_direction(self):
|
||||
# stream from END to START and fwds, expect newest first
|
||||
|
||||
# stream from END to START and bwds, expect oldest first
|
||||
|
||||
# stream from START to END and fwds, expect oldest first
|
||||
|
||||
# stream from START to END and bwds, expect newest first
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class EventStreamPermissionsTestCase(RestTestCase):
|
||||
""" Tests event streaming (GET /events). """
|
||||
|
||||
if PY3:
|
||||
skip = "Skip on Py3 until ported to use not V1 only register."
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
import synapse.rest.client.v1.events
|
||||
import synapse.rest.client.v1_only.register
|
||||
import synapse.rest.client.v1.room
|
||||
|
||||
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
|
|
@@ -125,7 +55,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
|
|||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
synapse.rest.client.v1.register.register_servlets(hs, self.mock_resource)
|
||||
synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource)
|
||||
synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource)
|
||||
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
|
||||
|
||||
|
|
|
|||
|
|
@@ -16,27 +16,26 @@
|
|||
import json
|
||||
|
||||
from mock import Mock
|
||||
from six import PY3
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.test.proto_helpers import MemoryReactorClock
|
||||
|
||||
from synapse.rest.client.v1.register import CreateUserRestServlet
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v1_only.register import register_servlets
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import mock_getRawHeaders
|
||||
from tests.server import make_request, setup_test_homeserver
|
||||
|
||||
|
||||
class CreateUserServletTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for CreateUserRestServlet.
|
||||
"""
|
||||
if PY3:
|
||||
skip = "Not ported to Python 3."
|
||||
|
||||
def setUp(self):
|
||||
# do the dance to hook up request data to self.request_data
|
||||
self.request_data = ""
|
||||
self.request = Mock(
|
||||
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
|
||||
path='/_matrix/client/api/v1/createUser'
|
||||
)
|
||||
self.request.args = {}
|
||||
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
|
||||
self.registration_handler = Mock()
|
||||
|
||||
self.appservice = Mock(sender="@as:test")
|
||||
|
|
@@ -44,39 +43,49 @@ class CreateUserServletTestCase(unittest.TestCase):
|
|||
get_app_service_by_token=Mock(return_value=self.appservice)
|
||||
)
|
||||
|
||||
# do the dance to hook things up to the hs global
|
||||
handlers = Mock(
|
||||
registration_handler=self.registration_handler,
|
||||
handlers = Mock(registration_handler=self.registration_handler)
|
||||
self.clock = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
|
||||
self.hs = self.hs = setup_test_homeserver(
|
||||
http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
self.hs = Mock()
|
||||
self.hs.hostname = "superbig~testing~thing.com"
|
||||
self.hs.get_datastore = Mock(return_value=self.datastore)
|
||||
self.hs.get_handlers = Mock(return_value=handlers)
|
||||
self.servlet = CreateUserRestServlet(self.hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_createuser_with_valid_user(self):
|
||||
|
||||
res = JsonResource(self.hs)
|
||||
register_servlets(self.hs, res)
|
||||
|
||||
request_data = json.dumps(
|
||||
{
|
||||
"localpart": "someone",
|
||||
"displayname": "someone interesting",
|
||||
"duration_seconds": 200,
|
||||
}
|
||||
)
|
||||
|
||||
url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service'
|
||||
|
||||
user_id = "@someone:interesting"
|
||||
token = "my token"
|
||||
self.request.args = {
|
||||
"access_token": "i_am_an_app_service"
|
||||
}
|
||||
self.request_data = json.dumps({
|
||||
"localpart": "someone",
|
||||
"displayname": "someone interesting",
|
||||
"duration_seconds": 200
|
||||
})
|
||||
|
||||
self.registration_handler.get_or_create_user = Mock(
|
||||
return_value=(user_id, token)
|
||||
)
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(code, 200)
|
||||
request, channel = make_request(b"POST", url, request_data)
|
||||
request.render(res)
|
||||
|
||||
# Advance the clock because it waits
|
||||
self.clock.advance(1)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"200")
|
||||
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -16,13 +16,14 @@
|
|||
import json
|
||||
import time
|
||||
|
||||
# twisted imports
|
||||
import attr
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
|
||||
# trial imports
|
||||
from tests import unittest
|
||||
from tests.server import make_request, wait_until_result
|
||||
|
||||
|
||||
class RestTestCase(unittest.TestCase):
|
||||
|
|
@@ -133,3 +134,113 @@ class RestTestCase(unittest.TestCase):
|
|||
for key in required:
|
||||
self.assertEquals(required[key], actual[key],
|
||||
msg="%s mismatch. %s" % (key, actual))
|
||||
|
||||
|
||||
@attr.s
|
||||
class RestHelper(object):
|
||||
"""Contains extra helper functions to quickly and clearly perform a given
|
||||
REST action, which isn't the focus of the test.
|
||||
"""
|
||||
|
||||
hs = attr.ib()
|
||||
resource = attr.ib()
|
||||
auth_user_id = attr.ib()
|
||||
|
||||
def create_room_as(self, room_creator, is_public=True, tok=None):
|
||||
temp_id = self.auth_user_id
|
||||
self.auth_user_id = room_creator
|
||||
path = b"/_matrix/client/r0/createRoom"
|
||||
content = {}
|
||||
if not is_public:
|
||||
content["visibility"] = "private"
|
||||
if tok:
|
||||
path = path + b"?access_token=%s" % tok.encode('ascii')
|
||||
|
||||
request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8'))
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.hs.get_reactor(), channel)
|
||||
|
||||
assert channel.result["code"] == b"200", channel.result
|
||||
self.auth_user_id = temp_id
|
||||
return channel.json_body["room_id"]
|
||||
|
||||
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
|
||||
self.change_membership(
|
||||
room=room,
|
||||
src=src,
|
||||
targ=targ,
|
||||
tok=tok,
|
||||
membership=Membership.INVITE,
|
||||
expect_code=expect_code,
|
||||
)
|
||||
|
||||
def join(self, room=None, user=None, expect_code=200, tok=None):
|
||||
self.change_membership(
|
||||
room=room,
|
||||
src=user,
|
||||
targ=user,
|
||||
tok=tok,
|
||||
membership=Membership.JOIN,
|
||||
expect_code=expect_code,
|
||||
)
|
||||
|
||||
def leave(self, room=None, user=None, expect_code=200, tok=None):
|
||||
self.change_membership(
|
||||
room=room,
|
||||
src=user,
|
||||
targ=user,
|
||||
tok=tok,
|
||||
membership=Membership.LEAVE,
|
||||
expect_code=expect_code,
|
||||
)
|
||||
|
||||
def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
|
||||
temp_id = self.auth_user_id
|
||||
self.auth_user_id = src
|
||||
|
||||
path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
|
||||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
data = {"membership": membership}
|
||||
|
||||
request, channel = make_request(
|
||||
b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8')
|
||||
)
|
||||
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.hs.get_reactor(), channel)
|
||||
|
||||
assert int(channel.result["code"]) == expect_code, (
|
||||
"Expected: %d, got: %d, resp: %r"
|
||||
% (expect_code, int(channel.result["code"]), channel.result["body"])
|
||||
)
|
||||
|
||||
self.auth_user_id = temp_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register(self, user_id):
|
||||
(code, response) = yield self.mock_resource.trigger(
|
||||
"POST",
|
||||
"/_matrix/client/r0/register",
|
||||
json.dumps(
|
||||
{"user": user_id, "password": "test", "type": "m.login.password"}
|
||||
),
|
||||
)
|
||||
self.assertEquals(200, code)
|
||||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
if body is None:
|
||||
body = "body_text_here"
|
||||
|
||||
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
|
||||
content = '{"msgtype":"m.text","body":"%s"}' % body
|
||||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
(code, response) = yield self.mock_resource.trigger("PUT", path, content)
|
||||
self.assertEquals(expect_code, code, msg=str(response))
|
||||
|
|
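The new RestHelper packages the room-management requests a test needs as setup rather than as the behaviour under test. A sketch of how a test case might drive it; hs, resource and the access tokens are assumed to come from the surrounding test's setUp:

    helper = RestHelper(hs, resource, "@alice:test")

    room_id = helper.create_room_as("@alice:test", is_public=False, tok=alice_tok)
    helper.invite(room=room_id, src="@alice:test", targ="@bob:test", tok=alice_tok)
    helper.join(room=room_id, user="@bob:test", tok=bob_tok)
    helper.leave(room=room_id, user="@bob:test", tok=bob_tok)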
|
|||
|
|
@@ -1,61 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests import unittest
|
||||
|
||||
from ....utils import MockHttpResource, setup_test_homeserver
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
||||
|
||||
|
||||
class V2AlphaRestTestCase(unittest.TestCase):
|
||||
# Consumer must define
|
||||
# USER_ID = <some string>
|
||||
# TO_REGISTER = [<list of REST servlets to register>]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
datastore=self.make_datastore_mock(),
|
||||
http_client=None,
|
||||
resource_for_client=self.mock_resource,
|
||||
resource_for_federation=self.mock_resource,
|
||||
)
|
||||
|
||||
def get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.USER_ID),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
hs.get_auth().get_user_by_access_token = get_user_by_access_token
|
||||
|
||||
for r in self.TO_REGISTER:
|
||||
r.register_servlets(hs, self.mock_resource)
|
||||
|
||||
def make_datastore_mock(self):
|
||||
store = Mock(spec=[
|
||||
"insert_client_ip",
|
||||
])
|
||||
store.get_app_service_by_token = Mock(return_value=None)
|
||||
return store
|
||||
|
|
@@ -13,35 +13,37 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v2_alpha import filter
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
from ....utils import MockHttpResource, setup_test_homeserver
|
||||
from tests.server import (
|
||||
ThreadedMemoryReactorClock as MemoryReactorClock,
|
||||
make_request,
|
||||
setup_test_homeserver,
|
||||
wait_until_result,
|
||||
)
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
||||
|
||||
|
||||
class FilterTestCase(unittest.TestCase):
|
||||
|
||||
USER_ID = "@apple:test"
|
||||
USER_ID = b"@apple:test"
|
||||
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
|
||||
EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
||||
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
||||
TO_REGISTER = [filter]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
|
||||
self.clock = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
|
||||
self.hs = yield setup_test_homeserver(
|
||||
http_client=None,
|
||||
resource_for_client=self.mock_resource,
|
||||
resource_for_federation=self.mock_resource,
|
||||
self.hs = setup_test_homeserver(
|
||||
http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
|
||||
self.auth = self.hs.get_auth()
|
||||
|
|
@@ -55,82 +57,103 @@ class FilterTestCase(unittest.TestCase):
|
|||
|
||||
def get_user_by_req(request, allow_guest=False, rights="access"):
|
||||
return synapse.types.create_requester(
|
||||
UserID.from_string(self.USER_ID), 1, False, None)
|
||||
UserID.from_string(self.USER_ID), 1, False, None
|
||||
)
|
||||
|
||||
self.auth.get_user_by_access_token = get_user_by_access_token
|
||||
self.auth.get_user_by_req = get_user_by_req
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
self.filtering = self.hs.get_filtering()
|
||||
self.resource = JsonResource(self.hs)
|
||||
|
||||
for r in self.TO_REGISTER:
|
||||
r.register_servlets(self.hs, self.mock_resource)
|
||||
r.register_servlets(self.hs, self.resource)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_add_filter(self):
|
||||
(code, response) = yield self.mock_resource.trigger(
|
||||
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
|
||||
request, channel = make_request(
|
||||
b"POST",
|
||||
b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
||||
self.EXAMPLE_FILTER_JSON,
|
||||
)
|
||||
self.assertEquals(200, code)
|
||||
self.assertEquals({"filter_id": "0"}, response)
|
||||
filter = yield self.store.get_user_filter(
|
||||
user_localpart='apple',
|
||||
filter_id=0,
|
||||
)
|
||||
self.assertEquals(filter, self.EXAMPLE_FILTER)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
||||
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
||||
self.clock.advance(0)
|
||||
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_add_filter_for_other_user(self):
|
||||
(code, response) = yield self.mock_resource.trigger(
|
||||
"POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON
|
||||
request, channel = make_request(
|
||||
b"POST",
|
||||
b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"),
|
||||
self.EXAMPLE_FILTER_JSON,
|
||||
)
|
||||
self.assertEquals(403, code)
|
||||
self.assertEquals(response['errcode'], Codes.FORBIDDEN)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"403")
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_add_filter_non_local_user(self):
|
||||
_is_mine = self.hs.is_mine
|
||||
self.hs.is_mine = lambda target_user: False
|
||||
(code, response) = yield self.mock_resource.trigger(
|
||||
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
|
||||
request, channel = make_request(
|
||||
b"POST",
|
||||
b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
||||
self.EXAMPLE_FILTER_JSON,
|
||||
)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.hs.is_mine = _is_mine
|
||||
self.assertEquals(403, code)
|
||||
self.assertEquals(response['errcode'], Codes.FORBIDDEN)
|
||||
self.assertEqual(channel.result["code"], b"403")
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_filter(self):
|
||||
filter_id = yield self.filtering.add_user_filter(
|
||||
user_localpart='apple',
|
||||
user_filter=self.EXAMPLE_FILTER
|
||||
filter_id = self.filtering.add_user_filter(
|
||||
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
|
||||
)
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/user/%s/filter/%s" % (self.USER_ID, filter_id)
|
||||
self.clock.advance(1)
|
||||
filter_id = filter_id.result
|
||||
request, channel = make_request(
|
||||
b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
|
||||
)
|
||||
self.assertEquals(200, code)
|
||||
self.assertEquals(self.EXAMPLE_FILTER, response)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_filter_non_existant(self):
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/user/%s/filter/12382148321" % (self.USER_ID)
|
||||
request, channel = make_request(
|
||||
b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
|
||||
)
|
||||
self.assertEquals(400, code)
|
||||
self.assertEquals(response['errcode'], Codes.NOT_FOUND)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400")
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
||||
# Currently invalid params do not have an appropriate errcode
|
||||
# in errors.py
|
||||
@defer.inlineCallbacks
|
||||
def test_get_filter_invalid_id(self):
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/user/%s/filter/foobar" % (self.USER_ID)
|
||||
request, channel = make_request(
|
||||
b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
|
||||
)
|
||||
self.assertEquals(400, code)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400")
|
||||
|
||||
# No ID also returns an invalid_id error
|
||||
@defer.inlineCallbacks
|
||||
def test_get_filter_no_id(self):
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/user/%s/filter/" % (self.USER_ID)
|
||||
request, channel = make_request(
|
||||
b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
|
||||
)
|
||||
self.assertEquals(400, code)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400")
|
||||
|
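The rewritten filter tests swap MockHttpResource.trigger for the in-memory request pipeline from tests.server: build a request, render it against the JsonResource, and pump the fake reactor until the channel holds a result. Reduced to its essentials, the pattern repeated throughout this diff looks like the following; the path is illustrative:

    request, channel = make_request(b"GET", b"/_matrix/client/r0/some/path")
    request.render(self.resource)            # dispatch to the registered servlets
    wait_until_result(self.clock, channel)   # spin the memory reactor until done

    self.assertEqual(channel.result["code"], b"200")
    body = channel.json_body                 # decoded JSON response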
|
|
|||
|
|
@@ -2,165 +2,192 @@ import json
|
|||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.python import failure
|
||||
from twisted.test.proto_helpers import MemoryReactorClock
|
||||
|
||||
from synapse.api.errors import InteractiveAuthIncompleteError, SynapseError
|
||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v2_alpha.register import register_servlets
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import mock_getRawHeaders
|
||||
from tests.server import make_request, setup_test_homeserver, wait_until_result
|
||||
|
||||
|
||||
class RegisterRestServletTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# do the dance to hook up request data to self.request_data
|
||||
self.request_data = ""
|
||||
self.request = Mock(
|
||||
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
|
||||
path='/_matrix/api/v2_alpha/register'
|
||||
)
|
||||
self.request.args = {}
|
||||
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
|
||||
self.clock = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
self.url = b"/_matrix/client/r0/register"
|
||||
|
||||
self.appservice = None
|
||||
self.auth = Mock(get_appservice_by_req=Mock(
|
||||
side_effect=lambda x: self.appservice)
|
||||
self.auth = Mock(
|
||||
get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
|
||||
)
|
||||
|
||||
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
|
||||
self.auth_handler = Mock(
|
||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||
get_session_data=Mock(return_value=None)
|
||||
get_session_data=Mock(return_value=None),
|
||||
)
|
||||
self.registration_handler = Mock()
|
||||
self.identity_handler = Mock()
|
||||
self.login_handler = Mock()
|
||||
self.device_handler = Mock()
|
||||
self.device_handler.check_device_registered = Mock(return_value="FAKE")
|
||||
|
||||
self.datastore = Mock(return_value=Mock())
|
||||
self.datastore.get_current_state_deltas = Mock(return_value=[])
|
||||
|
||||
# do the dance to hook it up to the hs global
|
||||
self.handlers = Mock(
|
||||
registration_handler=self.registration_handler,
|
||||
identity_handler=self.identity_handler,
|
||||
login_handler=self.login_handler
|
||||
login_handler=self.login_handler,
|
||||
)
|
||||
self.hs = setup_test_homeserver(
|
||||
http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
self.hs = Mock()
|
||||
self.hs.hostname = "superbig~testing~thing.com"
|
||||
self.hs.get_auth = Mock(return_value=self.auth)
|
||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
||||
self.hs.get_datastore = Mock(return_value=self.datastore)
|
||||
self.hs.config.enable_registration = True
|
||||
self.hs.config.registrations_require_3pid = []
|
||||
self.hs.config.auto_join_rooms = []
|
||||
|
||||
# init the thing we're testing
|
||||
self.servlet = RegisterRestServlet(self.hs)
|
||||
self.resource = JsonResource(self.hs)
|
||||
register_servlets(self.hs, self.resource)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_appservice_registration_valid(self):
|
||||
user_id = "@kermit:muppet"
|
||||
token = "kermits_access_token"
|
||||
self.request.args = {
|
||||
"access_token": "i_am_an_app_service"
|
||||
}
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit"
|
||||
})
|
||||
self.appservice = {
|
||||
"id": "1234"
|
||||
}
|
||||
self.registration_handler.appservice_register = Mock(
|
||||
return_value=user_id
|
||||
)
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(
|
||||
return_value=token
|
||||
)
|
||||
self.appservice = {"id": "1234"}
|
||||
self.registration_handler.appservice_register = Mock(return_value=user_id)
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
||||
request_data = json.dumps({"username": "kermit"})
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(code, 200)
|
||||
request, channel = make_request(
|
||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||
)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_appservice_registration_invalid(self):
|
||||
self.request.args = {
|
||||
"access_token": "i_am_an_app_service"
|
||||
}
|
||||
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit"
|
||||
})
|
||||
self.appservice = None # no application service exists
|
||||
result = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(result, (401, None))
|
||||
request_data = json.dumps({"username": "kermit"})
|
||||
request, channel = make_request(
|
||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||
)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
|
||||
def test_POST_bad_password(self):
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit",
|
||||
"password": 666
|
||||
})
|
||||
d = self.servlet.on_POST(self.request)
|
||||
return self.assertFailure(d, SynapseError)
|
||||
request_data = json.dumps({"username": "kermit", "password": 666})
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"400", channel.result)
|
||||
self.assertEquals(
|
||||
json.loads(channel.result["body"])["error"], "Invalid password"
|
||||
)
|
||||
|
||||
def test_POST_bad_username(self):
|
||||
self.request_data = json.dumps({
|
||||
"username": 777,
|
||||
"password": "monkey"
|
||||
})
|
||||
d = self.servlet.on_POST(self.request)
|
||||
return self.assertFailure(d, SynapseError)
|
||||
request_data = json.dumps({"username": 777, "password": "monkey"})
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"400", channel.result)
|
||||
self.assertEquals(
|
||||
json.loads(channel.result["body"])["error"], "Invalid username"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_user_valid(self):
|
||||
user_id = "@kermit:muppet"
|
||||
token = "kermits_access_token"
|
||||
device_id = "frogfone"
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit",
|
||||
"password": "monkey",
|
||||
"device_id": device_id,
|
||||
})
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (None, {
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
}, None)
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(
|
||||
return_value=token
|
||||
request_data = json.dumps(
|
||||
{"username": "kermit", "password": "monkey", "device_id": device_id}
|
||||
)
|
||||
self.device_handler.check_device_registered = \
|
||||
Mock(return_value=device_id)
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
||||
self.device_handler.check_device_registered = Mock(return_value=device_id)
|
||||
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(code, 200)
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
|
||||
self.auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id=device_id, initial_device_display_name=None)
|
||||
user_id, device_id=device_id, initial_device_display_name=None
|
||||
)
|
||||
|
||||
def test_POST_disabled_registration(self):
|
||||
self.hs.config.enable_registration = False
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
})
|
||||
request_data = json.dumps({"username": "kermit", "password": "monkey"})
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (None, {
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
}, None)
|
||||
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
||||
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
||||
d = self.servlet.on_POST(self.request)
|
||||
return self.assertFailure(d, SynapseError)
|
||||
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||
self.assertEquals(
|
||||
json.loads(channel.result["body"])["error"],
|
||||
"Registration has been disabled",
|
||||
)
|
||||
|
||||
def test_POST_guest_registration(self):
|
||||
user_id = "a@b"
|
||||
self.hs.config.macaroon_secret_key = "test"
|
||||
self.hs.config.allow_guest_access = True
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
|
||||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": "guest_device",
|
||||
}
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
|
||||
|
||||
def test_POST_disabled_guest_registration(self):
|
||||
self.hs.config.allow_guest_access = False
|
||||
|
||||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||
self.assertEquals(
|
||||
json.loads(channel.result["body"])["error"], "Guest access is disabled"
|
||||
)
|
||||
|
|
|
|||
tests/rest/client/v2_alpha/test_sync.py (new file, 87 lines)
|
|
@@ -0,0 +1,87 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse.types
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v2_alpha import sync
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import (
|
||||
ThreadedMemoryReactorClock as MemoryReactorClock,
|
||||
make_request,
|
||||
setup_test_homeserver,
|
||||
wait_until_result,
|
||||
)
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
||||
|
||||
|
||||
class FilterTestCase(unittest.TestCase):
|
||||
|
||||
USER_ID = b"@apple:test"
|
||||
TO_REGISTER = [sync]
|
||||
|
||||
def setUp(self):
|
||||
self.clock = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
|
||||
self.hs = setup_test_homeserver(
|
||||
http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
|
||||
self.auth = self.hs.get_auth()
|
||||
|
||||
def get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.USER_ID),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
|
||||
def get_user_by_req(request, allow_guest=False, rights="access"):
|
||||
return synapse.types.create_requester(
|
||||
UserID.from_string(self.USER_ID), 1, False, None
|
||||
)
|
||||
|
||||
self.auth.get_user_by_access_token = get_user_by_access_token
|
||||
self.auth.get_user_by_req = get_user_by_req
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
self.filtering = self.hs.get_filtering()
|
||||
self.resource = JsonResource(self.hs)
|
||||
|
||||
for r in self.TO_REGISTER:
|
||||
r.register_servlets(self.hs, self.resource)
|
||||
|
||||
def test_sync_argless(self):
|
||||
request, channel = make_request(b"GET", b"/_matrix/client/r0/sync")
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertTrue(
|
||||
set(
|
||||
[
|
||||
"next_batch",
|
||||
"rooms",
|
||||
"presence",
|
||||
"account_data",
|
||||
"to_device",
|
||||
"device_lists",
|
||||
]
|
||||
).issubset(set(channel.json_body.keys()))
|
||||
)
|
||||
|
|
@@ -80,6 +80,11 @@ def make_request(method, path, content=b""):
|
|||
content, and return the Request and the Channel underneath.
|
||||
"""
|
||||
|
||||
# Decorate it to be the full path
|
||||
if not path.startswith(b"/_matrix"):
|
||||
path = b"/_matrix/client/r0/" + path
|
||||
path = path.replace("//", "/")
|
||||
|
||||
if isinstance(content, text_type):
|
||||
content = content.encode('utf8')
|
||||
|
||||
|
|
@@ -110,6 +115,11 @@ def wait_until_result(clock, channel, timeout=100):
|
|||
clock.advance(0.1)
|
||||
|
||||
|
||||
def render(request, resource, clock):
|
||||
request.render(resource)
|
||||
wait_until_result(clock, request._channel)
|
||||
|
||||
|
||||
class ThreadedMemoryReactorClock(MemoryReactorClock):
|
||||
"""
|
||||
A MemoryReactorClock that supports callFromThread.
|
||||
|
|
|
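The render() helper added above folds the request.render + wait_until_result pair into a single call, which is what the new tests/rest/client/v1/test_admin.py leans on. Usage as in those tests:

    request, channel = make_request("POST", self.url, b'{}')
    render(request, self.resource, self.clock)

    self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])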
|||
tests/storage/test_state.py (new file, 319 lines)
|
|
@@ -0,0 +1,319 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.types import RoomID, UserID
|
||||
|
||||
import tests.unittest
|
||||
import tests.utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StateStoreTestCase(tests.unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(StateStoreTestCase, self).__init__(*args, **kwargs)
|
||||
self.store = None # type: synapse.storage.DataStore
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
hs = yield tests.utils.setup_test_homeserver()
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
self.u_alice = UserID.from_string("@alice:test")
|
||||
self.u_bob = UserID.from_string("@bob:test")
|
||||
|
||||
self.room = RoomID.from_string("!abc123:test")
|
||||
|
||||
yield self.store.store_room(
|
||||
self.room.to_string(),
|
||||
room_creator_user_id="@creator:text",
|
||||
is_public=True
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_state_event(self, room, sender, typ, state_key, content):
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": typ,
|
||||
"sender": sender.to_string(),
|
||||
"state_key": state_key,
|
||||
"room_id": room.to_string(),
|
||||
"content": content,
|
||||
})
|
||||
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield self.store.persist_event(event, context)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
def assertStateMapEqual(self, s1, s2):
|
||||
for t in s1:
|
||||
# just compare event IDs for simplicity
|
||||
self.assertEqual(s1[t].event_id, s2[t].event_id)
|
||||
self.assertEqual(len(s1), len(s2))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_state_for_event(self):
|
||||
|
||||
# this defaults to a linear DAG as each new injection defaults to whatever
|
||||
# forward extremities are currently in the DB for this room.
|
||||
e1 = yield self.inject_state_event(
|
||||
self.room, self.u_alice, EventTypes.Create, '', {},
|
||||
)
|
||||
e2 = yield self.inject_state_event(
|
||||
self.room, self.u_alice, EventTypes.Name, '', {
|
||||
"name": "test room"
|
||||
},
|
||||
)
|
||||
e3 = yield self.inject_state_event(
|
||||
self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), {
|
||||
"membership": Membership.JOIN
|
||||
},
|
||||
)
|
||||
e4 = yield self.inject_state_event(
|
||||
self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
|
||||
"membership": Membership.JOIN
|
||||
},
|
||||
)
|
||||
e5 = yield self.inject_state_event(
|
||||
self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
|
||||
"membership": Membership.LEAVE
|
||||
},
|
||||
)
|
||||
|
||||
# check we get the full state as of the final event
|
||||
state = yield self.store.get_state_for_event(
|
||||
e5.event_id, None, filtered_types=None
|
||||
)
|
||||
|
||||
self.assertIsNotNone(e4)
|
||||
|
||||
self.assertStateMapEqual({
|
||||
(e1.type, e1.state_key): e1,
|
||||
(e2.type, e2.state_key): e2,
|
||||
(e3.type, e3.state_key): e3,
|
||||
# e4 is overwritten by e5
|
||||
(e5.type, e5.state_key): e5,
|
||||
}, state)
|
||||
|
||||
# check we can filter to the m.room.name event (with a '' state key)
|
||||
state = yield self.store.get_state_for_event(
|
||||
e5.event_id, [(EventTypes.Name, '')], filtered_types=None
|
||||
)
|
||||
|
||||
self.assertStateMapEqual({
|
||||
(e2.type, e2.state_key): e2,
|
||||
}, state)
|
||||
|
||||
# check we can filter to the m.room.name event (with a wildcard None state key)
|
||||
state = yield self.store.get_state_for_event(
|
||||
e5.event_id, [(EventTypes.Name, None)], filtered_types=None
|
||||
)
|
||||
|
||||
self.assertStateMapEqual({
|
||||
(e2.type, e2.state_key): e2,
|
||||
}, state)
|
||||
|
||||
# check we can grab the m.room.member events (with a wildcard None state key)
|
||||
state = yield self.store.get_state_for_event(
|
||||
e5.event_id, [(EventTypes.Member, None)], filtered_types=None
|
||||
)
|
||||
|
||||
self.assertStateMapEqual({
|
||||
(e3.type, e3.state_key): e3,
|
||||
(e5.type, e5.state_key): e5,
|
||||
}, state)
|
||||
|
||||
# check we can use filter_types to grab a specific room member
|
||||
# without filtering out the other event types
|
||||
state = yield self.store.get_state_for_event(
|
||||
e5.event_id, [(EventTypes.Member, self.u_alice.to_string())],
|
||||
filtered_types=[EventTypes.Member],
|
||||
)
|
||||
|
||||
self.assertStateMapEqual({
|
||||
(e1.type, e1.state_key): e1,
|
||||
(e2.type, e2.state_key): e2,
|
||||
(e3.type, e3.state_key): e3,
|
||||
}, state)
|
||||
|
||||
# check that types=[], filtered_types=[EventTypes.Member]
|
||||
# doesn't return all members
|
||||
state = yield self.store.get_state_for_event(
|
||||
e5.event_id, [], filtered_types=[EventTypes.Member],
|
||||
)
|
||||
|
||||
self.assertStateMapEqual({
|
||||
(e1.type, e1.state_key): e1,
|
||||
(e2.type, e2.state_key): e2,
|
||||
}, state)
|
||||
|
||||
#######################################################
|
||||
# _get_some_state_from_cache tests against a full cache
|
||||
#######################################################
|
||||
|
||||
room_id = self.room.to_string()
|
||||
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
|
||||
group = group_ids.keys()[0]
|
||||
|
||||
# test _get_some_state_from_cache correctly filters out members with types=[]
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
group, [], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual({
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
(e2.type, e2.state_key): e2.event_id,
|
||||
}, state_dict)
|
||||
|
||||
# test _get_some_state_from_cache correctly filters in members with wildcard types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual({
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
(e2.type, e2.state_key): e2.event_id,
|
||||
(e3.type, e3.state_key): e3.event_id,
|
||||
# e4 is overwritten by e5
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
}, state_dict)
|
||||
|
||||
# test _get_some_state_from_cache correctly filters in members with specific types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual({
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
(e2.type, e2.state_key): e2.event_id,
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
}, state_dict)
|
||||
|
||||
# test _get_some_state_from_cache correctly filters in members with specific types
|
||||
# and no filtered_types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual({
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
}, state_dict)
|
||||
|
||||
#######################################################
|
||||
# deliberately remove e2 (room name) from the _state_group_cache
|
||||
|
||||
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertEqual(known_absent, set())
|
||||
self.assertDictEqual(state_dict_ids, {
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
(e2.type, e2.state_key): e2.event_id,
|
||||
(e3.type, e3.state_key): e3.event_id,
|
||||
# e4 is overwritten by e5
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
})
|
||||
|
||||
state_dict_ids.pop((e2.type, e2.state_key))
|
||||
self.store._state_group_cache.invalidate(group)
|
||||
self.store._state_group_cache.update(
|
||||
sequence=self.store._state_group_cache.sequence,
|
||||
key=group,
|
||||
value=state_dict_ids,
|
||||
# list fetched keys so it knows it's partial
|
||||
fetched_keys=(
|
||||
(e1.type, e1.state_key),
|
||||
(e3.type, e3.state_key),
|
||||
(e5.type, e5.state_key),
|
||||
)
|
||||
)
|
||||
|
||||
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
|
||||
|
||||
self.assertEqual(is_all, False)
|
||||
self.assertEqual(known_absent, set([
|
||||
(e1.type, e1.state_key),
|
||||
(e3.type, e3.state_key),
|
||||
(e5.type, e5.state_key),
|
||||
]))
|
||||
self.assertDictEqual(state_dict_ids, {
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
(e3.type, e3.state_key): e3.event_id,
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
})
|
||||
|
||||
############################################
|
||||
# test that things work with a partial cache
|
||||
|
||||
# test _get_some_state_from_cache correctly filters out members with types=[]
|
||||
room_id = self.room.to_string()
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
group, [], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, False)
|
||||
self.assertDictEqual({
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
}, state_dict)
|
||||
|
||||
# test _get_some_state_from_cache correctly filters in members wildcard types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, False)
|
||||
self.assertDictEqual({
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
(e3.type, e3.state_key): e3.event_id,
|
||||
# e4 is overwritten by e5
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
}, state_dict)
|
||||
|
||||
# test _get_some_state_from_cache correctly filters in members with specific types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, False)
|
||||
self.assertDictEqual({
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
}, state_dict)
|
||||
|
||||
# test _get_some_state_from_cache correctly filters in members with specific types
|
||||
# and no filtered_types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual({
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
}, state_dict)
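The assertions above all exercise the same rule for get_state_for_event and _get_some_state_from_cache: `types` is a list of (event type, state key) pairs, with a state key of None meaning "any key", and `filtered_types` names the event types that the `types` filter actually applies to, every other type being passed through untouched; when `filtered_types` is None the filter applies to all types. A rough pure-Python sketch of that rule, illustrative only and not the store's implementation:

def apply_state_filter(state, types, filtered_types):
    # state: (type, state_key) -> event_id
    # types: list of (type, state_key), state_key=None meaning "any key"
    # filtered_types: event types the filter applies to, or None for "all"
    def keep(key):
        typ, state_key = key
        if filtered_types is not None and typ not in filtered_types:
            return True  # this type is not being filtered at all
        return any(
            typ == want_type and want_key in (None, state_key)
            for want_type, want_key in types
        )

    return {k: v for k, v in state.items() if keep(k)}

# e.g. types=[("m.room.member", "@alice:test")] with
# filtered_types=["m.room.member"] keeps all non-member state plus only
# Alice's membership, matching the check on Alice's membership above.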
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
|
@ -15,8 +16,6 @@
|
|||
|
||||
from mock import Mock, patch
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.distributor import Distributor
|
||||
|
||||
from . import unittest
|
||||
|
|
@ -27,38 +26,15 @@ class DistributorTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.dist = Distributor()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_dispatch(self):
|
||||
self.dist.declare("alert")
|
||||
|
||||
observer = Mock()
|
||||
self.dist.observe("alert", observer)
|
||||
|
||||
d = self.dist.fire("alert", 1, 2, 3)
|
||||
yield d
|
||||
self.assertTrue(d.called)
|
||||
self.dist.fire("alert", 1, 2, 3)
|
||||
observer.assert_called_with(1, 2, 3)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_dispatch_deferred(self):
|
||||
self.dist.declare("whine")
|
||||
|
||||
d_inner = defer.Deferred()
|
||||
|
||||
def observer():
|
||||
return d_inner
|
||||
|
||||
self.dist.observe("whine", observer)
|
||||
|
||||
d_outer = self.dist.fire("whine")
|
||||
|
||||
self.assertFalse(d_outer.called)
|
||||
|
||||
d_inner.callback(None)
|
||||
yield d_outer
|
||||
self.assertTrue(d_outer.called)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_catch(self):
|
||||
self.dist.declare("alarm")
|
||||
|
||||
|
|
@ -71,9 +47,7 @@ class DistributorTestCase(unittest.TestCase):
|
|||
with patch(
|
||||
"synapse.util.distributor.logger", spec=["warning"]
|
||||
) as mock_logger:
|
||||
d = self.dist.fire("alarm", "Go")
|
||||
yield d
|
||||
self.assertTrue(d.called)
|
||||
self.dist.fire("alarm", "Go")
|
||||
|
||||
observers[0].assert_called_once_with("Go")
|
||||
observers[1].assert_called_once_with("Go")
|
||||
|
|
@ -83,34 +57,12 @@ class DistributorTestCase(unittest.TestCase):
|
|||
mock_logger.warning.call_args[0][0], str
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_catch_no_suppress(self):
|
||||
# Gut-wrenching
|
||||
self.dist.suppress_failures = False
|
||||
|
||||
self.dist.declare("whail")
|
||||
|
||||
class MyException(Exception):
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def observer():
|
||||
raise MyException("Oopsie")
|
||||
|
||||
self.dist.observe("whail", observer)
|
||||
|
||||
d = self.dist.fire("whail")
|
||||
|
||||
yield self.assertFailure(d, MyException)
|
||||
self.dist.suppress_failures = True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_prereg(self):
|
||||
observer = Mock()
|
||||
self.dist.observe("flare", observer)
|
||||
|
||||
self.dist.declare("flare")
|
||||
yield self.dist.fire("flare", 4, 5)
|
||||
self.dist.fire("flare", 4, 5)
|
||||
|
||||
observer.assert_called_with(4, 5)
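The edits to this file all make the same change: fire() is called without yielding on a returned Deferred, and the observer is asserted on directly afterwards. A minimal sketch of the resulting pattern, mirroring test_signal_dispatch above rather than adding anything new, using the imports already present in this module:

from mock import Mock

from synapse.util.distributor import Distributor

dist = Distributor()
dist.declare("alert")              # a signal must be declared before use

observer = Mock()
dist.observe("alert", observer)

# fire() is no longer yielded on in these tests; the observer can be
# checked immediately afterwards.
dist.fire("alert", 1, 2, 3)
observer.assert_called_with(1, 2, 3)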
|
||||
|
||||
|
|
|
|||
|
|
@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase):
|
|||
)
|
||||
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
|
||||
|
||||
@unittest.DEBUG
|
||||
def test_cant_hide_past_history(self):
|
||||
"""
|
||||
If you send a message, you must be able to provide the direct
|
||||
|
|
@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase):
|
|||
for x, y in d.items()
|
||||
if x == ("m.room.member", "@us:test")
|
||||
],
|
||||
"auth_chain_ids": d.values(),
|
||||
"auth_chain_ids": list(d.values()),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,9 +33,11 @@ class JsonResourceTests(unittest.TestCase):
|
|||
return (200, kwargs)
|
||||
|
||||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/foo/(?P<room_id>[^/]*)$")], _callback)
|
||||
res.register_paths(
|
||||
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
|
||||
)
|
||||
|
||||
request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83")
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
|
||||
request.render(res)
|
||||
|
||||
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
|
||||
|
|
@ -51,9 +53,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
raise Exception("boo")
|
||||
|
||||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/foo$")], _callback)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/foo")
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||
request.render(res)
|
||||
|
||||
self.assertEqual(channel.result["code"], b'500')
|
||||
|
|
@ -74,9 +76,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
return d
|
||||
|
||||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/foo$")], _callback)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/foo")
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||
request.render(res)
|
||||
|
||||
# No error has been raised yet
|
||||
|
|
@ -96,9 +98,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
|
||||
|
||||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/foo$")], _callback)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/foo")
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||
request.render(res)
|
||||
|
||||
self.assertEqual(channel.result["code"], b'403')
|
||||
|
|
@ -118,9 +120,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
self.fail("shouldn't ever get here")
|
||||
|
||||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/foo$")], _callback)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/foobar")
|
||||
request, channel = make_request(b"GET", b"/_matrix/foobar")
|
||||
request.render(res)
|
||||
|
||||
self.assertEqual(channel.result["code"], b'400')
|
||||
|
|
|
|||
|
|
@ -204,7 +204,8 @@ class StateTestCase(unittest.TestCase):
|
|||
self.store.register_event_context(event, context)
|
||||
context_store[event.event_id] = context
|
||||
|
||||
self.assertEqual(2, len(context_store["D"].prev_state_ids))
|
||||
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
|
||||
self.assertEqual(2, len(prev_state_ids))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_branch_basic_conflict(self):
|
||||
|
|
@ -255,9 +256,11 @@ class StateTestCase(unittest.TestCase):
|
|||
self.store.register_event_context(event, context)
|
||||
context_store[event.event_id] = context
|
||||
|
||||
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
|
||||
|
||||
self.assertSetEqual(
|
||||
{"START", "A", "C"},
|
||||
{e_id for e_id in context_store["D"].prev_state_ids.values()}
|
||||
{e_id for e_id in prev_state_ids.values()}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
@ -318,9 +321,11 @@ class StateTestCase(unittest.TestCase):
|
|||
self.store.register_event_context(event, context)
|
||||
context_store[event.event_id] = context
|
||||
|
||||
prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
|
||||
|
||||
self.assertSetEqual(
|
||||
{"START", "A", "B", "C"},
|
||||
{e for e in context_store["E"].prev_state_ids.values()}
|
||||
{e for e in prev_state_ids.values()}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
@ -398,9 +403,11 @@ class StateTestCase(unittest.TestCase):
|
|||
self.store.register_event_context(event, context)
|
||||
context_store[event.event_id] = context
|
||||
|
||||
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
|
||||
|
||||
self.assertSetEqual(
|
||||
{"A1", "A2", "A3", "A5", "B"},
|
||||
{e for e in context_store["D"].prev_state_ids.values()}
|
||||
{e for e in prev_state_ids.values()}
|
||||
)
|
||||
|
||||
def _add_depths(self, nodes, edges):
|
||||
|
|
@ -429,8 +436,10 @@ class StateTestCase(unittest.TestCase):
|
|||
event, old_state=old_state
|
||||
)
|
||||
|
||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||
|
||||
self.assertEqual(
|
||||
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
||||
set(e.event_id for e in old_state), set(current_state_ids.values())
|
||||
)
|
||||
|
||||
self.assertIsNotNone(context.state_group)
|
||||
|
|
@ -449,8 +458,10 @@ class StateTestCase(unittest.TestCase):
|
|||
event, old_state=old_state
|
||||
)
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||
|
||||
self.assertEqual(
|
||||
set(e.event_id for e in old_state), set(context.prev_state_ids.values())
|
||||
set(e.event_id for e in old_state), set(prev_state_ids.values())
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
@ -475,9 +486,11 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state]),
|
||||
set(context.current_state_ids.values())
|
||||
set(current_state_ids.values())
|
||||
)
|
||||
|
||||
self.assertEqual(group_name, context.state_group)
|
||||
|
|
@ -504,9 +517,11 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state]),
|
||||
set(context.prev_state_ids.values())
|
||||
set(prev_state_ids.values())
|
||||
)
|
||||
|
||||
self.assertIsNotNone(context.state_group)
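Every hunk in this file is the same API migration: an EventContext's state maps are no longer read as plain attributes (context.prev_state_ids and context.current_state_ids) but are fetched through deferred getters that take the store. A sketch of the new pattern as a helper method inside a test case like this one; _check_context_state is a hypothetical name:

@defer.inlineCallbacks
def _check_context_state(self, context):
    # both getters return a dict of (event type, state key) -> event id
    prev_state_ids = yield context.get_prev_state_ids(self.store)
    current_state_ids = yield context.get_current_state_ids(self.store)
    self.assertIsInstance(prev_state_ids, dict)
    self.assertIsInstance(current_state_ids, dict)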
|
||||
|
|
@ -545,7 +560,9 @@ class StateTestCase(unittest.TestCase):
|
|||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||
)
|
||||
|
||||
self.assertEqual(len(context.current_state_ids), 6)
|
||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||
|
||||
self.assertEqual(len(current_state_ids), 6)
|
||||
|
||||
self.assertIsNotNone(context.state_group)
|
||||
|
||||
|
|
@ -585,7 +602,9 @@ class StateTestCase(unittest.TestCase):
|
|||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||
)
|
||||
|
||||
self.assertEqual(len(context.current_state_ids), 6)
|
||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||
|
||||
self.assertEqual(len(current_state_ids), 6)
|
||||
|
||||
self.assertIsNotNone(context.state_group)
|
||||
|
||||
|
|
@ -642,8 +661,10 @@ class StateTestCase(unittest.TestCase):
|
|||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||
)
|
||||
|
||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||
|
||||
self.assertEqual(
|
||||
old_state_2[3].event_id, context.current_state_ids[("test1", "1")]
|
||||
old_state_2[3].event_id, current_state_ids[("test1", "1")]
|
||||
)
|
||||
|
||||
# Reverse the depth to make sure we are actually using the depths
|
||||
|
|
@ -670,8 +691,10 @@ class StateTestCase(unittest.TestCase):
|
|||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||
)
|
||||
|
||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||
|
||||
self.assertEqual(
|
||||
old_state_1[3].event_id, context.current_state_ids[("test1", "1")]
|
||||
old_state_1[3].event_id, current_state_ids[("test1", "1")]
|
||||
)
|
||||
|
||||
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
|
||||
|
|
|
|||
324
tests/test_visibility.py
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import succeed
|
||||
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.visibility import filter_events_for_server
|
||||
|
||||
import tests.unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEST_ROOM_ID = "!TEST:ROOM"
|
||||
|
||||
|
||||
class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.hs = yield setup_test_homeserver()
|
||||
self.event_creation_handler = self.hs.get_event_creation_handler()
|
||||
self.event_builder_factory = self.hs.get_event_builder_factory()
|
||||
self.store = self.hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_filtering(self):
|
||||
#
|
||||
# The events to be filtered consist of 10 membership events (it doesn't
|
||||
# really matter if they are joins or leaves, so let's make them joins).
|
||||
# One of those membership events is going to be for a user on the
|
||||
# server we are filtering for (so we can check the filtering is doing
|
||||
# the right thing).
|
||||
#
|
||||
|
||||
# before we do that, we persist some other events to act as state.
|
||||
self.inject_visibility("@admin:hs", "joined")
|
||||
for i in range(0, 10):
|
||||
yield self.inject_room_member("@resident%i:hs" % i)
|
||||
|
||||
events_to_filter = []
|
||||
|
||||
for i in range(0, 10):
|
||||
user = "@user%i:%s" % (
|
||||
i, "test_server" if i == 5 else "other_server"
|
||||
)
|
||||
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
|
||||
events_to_filter.append(evt)
|
||||
|
||||
filtered = yield filter_events_for_server(
|
||||
self.store, "test_server", events_to_filter,
|
||||
)
|
||||
|
||||
# the result should be 5 redacted events, and 5 unredacted events.
|
||||
for i in range(0, 5):
|
||||
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
|
||||
self.assertNotIn("a", filtered[i].content)
|
||||
|
||||
for i in range(5, 10):
|
||||
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
|
||||
self.assertEqual(filtered[i].content["a"], "b")
|
||||
|
||||
@tests.unittest.DEBUG
|
||||
@defer.inlineCallbacks
|
||||
def test_erased_user(self):
|
||||
# 4 message events, from erased and unerased users, with a membership
|
||||
# change in the middle of them.
|
||||
events_to_filter = []
|
||||
|
||||
evt = yield self.inject_message("@unerased:local_hs")
|
||||
events_to_filter.append(evt)
|
||||
|
||||
evt = yield self.inject_message("@erased:local_hs")
|
||||
events_to_filter.append(evt)
|
||||
|
||||
evt = yield self.inject_room_member("@joiner:remote_hs")
|
||||
events_to_filter.append(evt)
|
||||
|
||||
evt = yield self.inject_message("@unerased:local_hs")
|
||||
events_to_filter.append(evt)
|
||||
|
||||
evt = yield self.inject_message("@erased:local_hs")
|
||||
events_to_filter.append(evt)
|
||||
|
||||
# the erasey user gets erased
|
||||
self.hs.get_datastore().mark_user_erased("@erased:local_hs")
|
||||
|
||||
# ... and the filtering happens.
|
||||
filtered = yield filter_events_for_server(
|
||||
self.store, "test_server", events_to_filter,
|
||||
)
|
||||
|
||||
for i in range(0, len(events_to_filter)):
|
||||
self.assertEqual(
|
||||
events_to_filter[i].event_id, filtered[i].event_id,
|
||||
"Unexpected event at result position %i" % (i, )
|
||||
)
|
||||
|
||||
for i in (0, 3):
|
||||
self.assertEqual(
|
||||
events_to_filter[i].content["body"], filtered[i].content["body"],
|
||||
"Unexpected event content at result position %i" % (i,)
|
||||
)
|
||||
|
||||
for i in (1, 4):
|
||||
self.assertNotIn("body", filtered[i].content)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_visibility(self, user_id, visibility):
|
||||
content = {"history_visibility": visibility}
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": "m.room.history_visibility",
|
||||
"sender": user_id,
|
||||
"state_key": "",
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": content,
|
||||
})
|
||||
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder
|
||||
)
|
||||
yield self.hs.get_datastore().persist_event(event, context)
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_room_member(self, user_id, membership="join", extra_content={}):
|
||||
content = {"membership": membership}
|
||||
content.update(extra_content)
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": "m.room.member",
|
||||
"sender": user_id,
|
||||
"state_key": user_id,
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": content,
|
||||
})
|
||||
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield self.hs.get_datastore().persist_event(event, context)
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_message(self, user_id, content=None):
|
||||
if content is None:
|
||||
content = {"body": "testytest"}
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": "m.room.message",
|
||||
"sender": user_id,
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": content,
|
||||
})
|
||||
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield self.hs.get_datastore().persist_event(event, context)
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_large_room(self):
|
||||
# see what happens when we have a large room with hundreds of thousands
|
||||
# of membership events
|
||||
|
||||
# As above, the events to be filtered consist of 10 membership events,
|
||||
# where one of them is for a user on the server we are filtering for.
|
||||
|
||||
import cProfile
|
||||
import pstats
|
||||
import time
|
||||
|
||||
# we stub out the store, because building up all that state the normal
|
||||
# way is very slow.
|
||||
test_store = _TestStore()
|
||||
|
||||
# our initial state is 100000 membership events and one
|
||||
# history_visibility event.
|
||||
room_state = []
|
||||
|
||||
history_visibility_evt = FrozenEvent({
|
||||
"event_id": "$history_vis",
|
||||
"type": "m.room.history_visibility",
|
||||
"sender": "@resident_user_0:test.com",
|
||||
"state_key": "",
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": {"history_visibility": "joined"},
|
||||
})
|
||||
room_state.append(history_visibility_evt)
|
||||
test_store.add_event(history_visibility_evt)
|
||||
|
||||
for i in range(0, 100000):
|
||||
user = "@resident_user_%i:test.com" % (i, )
|
||||
evt = FrozenEvent({
|
||||
"event_id": "$res_event_%i" % (i, ),
|
||||
"type": "m.room.member",
|
||||
"state_key": user,
|
||||
"sender": user,
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": {
|
||||
"membership": "join",
|
||||
"extra": "zzz,"
|
||||
},
|
||||
})
|
||||
room_state.append(evt)
|
||||
test_store.add_event(evt)
|
||||
|
||||
events_to_filter = []
|
||||
for i in range(0, 10):
|
||||
user = "@user%i:%s" % (
|
||||
i, "test_server" if i == 5 else "other_server"
|
||||
)
|
||||
evt = FrozenEvent({
|
||||
"event_id": "$evt%i" % (i, ),
|
||||
"type": "m.room.member",
|
||||
"state_key": user,
|
||||
"sender": user,
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": {
|
||||
"membership": "join",
|
||||
"extra": "zzz",
|
||||
},
|
||||
})
|
||||
events_to_filter.append(evt)
|
||||
room_state.append(evt)
|
||||
|
||||
test_store.add_event(evt)
|
||||
test_store.set_state_ids_for_event(evt, {
|
||||
(e.type, e.state_key): e.event_id for e in room_state
|
||||
})
|
||||
|
||||
pr = cProfile.Profile()
|
||||
pr.enable()
|
||||
|
||||
logger.info("Starting filtering")
|
||||
start = time.time()
|
||||
filtered = yield filter_events_for_server(
|
||||
test_store, "test_server", events_to_filter,
|
||||
)
|
||||
logger.info("Filtering took %f seconds", time.time() - start)
|
||||
|
||||
pr.disable()
|
||||
with open("filter_events_for_server.profile", "w+") as f:
|
||||
ps = pstats.Stats(pr, stream=f).sort_stats('cumulative')
|
||||
ps.print_stats()
|
||||
|
||||
# the result should be 5 redacted events, and 5 unredacted events.
|
||||
for i in range(0, 5):
|
||||
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
|
||||
self.assertNotIn("extra", filtered[i].content)
|
||||
|
||||
for i in range(5, 10):
|
||||
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
|
||||
self.assertEqual(filtered[i].content["extra"], "zzz")
|
||||
|
||||
test_large_room.skip = "Disabled by default because it's slow"
|
||||
|
||||
|
||||
class _TestStore(object):
|
||||
"""Implements a few methods of the DataStore, so that we can test
|
||||
filter_events_for_server
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
# data for get_events: a map from event_id to event
|
||||
self.events = {}
|
||||
|
||||
# data for get_state_ids_for_events mock: a map from event_id to
|
||||
# a map from (type, state_key) -> event_id for the state at that
|
||||
# event
|
||||
self.state_ids_for_events = {}
|
||||
|
||||
def add_event(self, event):
|
||||
self.events[event.event_id] = event
|
||||
|
||||
def set_state_ids_for_event(self, event, state):
|
||||
self.state_ids_for_events[event.event_id] = state
|
||||
|
||||
def get_state_ids_for_events(self, events, types):
|
||||
res = {}
|
||||
include_memberships = False
|
||||
for (type, state_key) in types:
|
||||
if type == "m.room.history_visibility":
|
||||
continue
|
||||
if type != "m.room.member" or state_key is not None:
|
||||
raise RuntimeError(
|
||||
"Unimplemented: get_state_ids with type (%s, %s)" %
|
||||
(type, state_key),
|
||||
)
|
||||
include_memberships = True
|
||||
|
||||
if include_memberships:
|
||||
for event_id in events:
|
||||
res[event_id] = self.state_ids_for_events[event_id]
|
||||
|
||||
else:
|
||||
k = ("m.room.history_visibility", "")
|
||||
for event_id in events:
|
||||
hve = self.state_ids_for_events[event_id][k]
|
||||
res[event_id] = {k: hve}
|
||||
|
||||
return succeed(res)
|
||||
|
||||
def get_events(self, events):
|
||||
return succeed({
|
||||
event_id: self.events[event_id] for event_id in events
|
||||
})
|
||||
|
||||
def are_users_erased(self, users):
|
||||
return succeed({u: False for u in users})
|
||||
|
|
@ -109,6 +109,17 @@ class TestCase(unittest.TestCase):
|
|||
except AssertionError as e:
|
||||
raise (type(e))(e.message + " for '.%s'" % key)
|
||||
|
||||
def assert_dict(self, required, actual):
|
||||
"""Does a partial assert of a dict.
|
||||
|
||||
Args:
|
||||
required (dict): The keys and value which MUST be in 'actual'.
|
||||
actual (dict): The test result. Extra keys will not be checked.
|
||||
"""
|
||||
for key in required:
|
||||
self.assertEquals(required[key], actual[key],
|
||||
msg="%s mismatch. %s" % (key, actual))
|
||||
|
||||
|
||||
def DEBUG(target):
|
||||
"""A decorator to set the .loglevel attribute to logging.DEBUG.
|
||||
|
|
|
|||
|
|
@ -1,70 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.async import Limiter
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class LimiterTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_limiter(self):
|
||||
limiter = Limiter(3)
|
||||
|
||||
key = object()
|
||||
|
||||
d1 = limiter.queue(key)
|
||||
cm1 = yield d1
|
||||
|
||||
d2 = limiter.queue(key)
|
||||
cm2 = yield d2
|
||||
|
||||
d3 = limiter.queue(key)
|
||||
cm3 = yield d3
|
||||
|
||||
d4 = limiter.queue(key)
|
||||
self.assertFalse(d4.called)
|
||||
|
||||
d5 = limiter.queue(key)
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
with cm1:
|
||||
self.assertFalse(d4.called)
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
self.assertTrue(d4.called)
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
with cm3:
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
self.assertTrue(d5.called)
|
||||
|
||||
with cm2:
|
||||
pass
|
||||
|
||||
with (yield d4):
|
||||
pass
|
||||
|
||||
with (yield d5):
|
||||
pass
|
||||
|
||||
d6 = limiter.queue(key)
|
||||
with (yield d6):
|
||||
pass
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
|
@ -16,6 +17,7 @@
|
|||
from six.moves import range
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.defer import CancelledError
|
||||
|
||||
from synapse.util import Clock, logcontext
|
||||
from synapse.util.async import Linearizer
|
||||
|
|
@ -65,3 +67,79 @@ class LinearizerTestCase(unittest.TestCase):
|
|||
func(i)
|
||||
|
||||
return func(1000)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_multiple_entries(self):
|
||||
limiter = Linearizer(max_count=3)
|
||||
|
||||
key = object()
|
||||
|
||||
d1 = limiter.queue(key)
|
||||
cm1 = yield d1
|
||||
|
||||
d2 = limiter.queue(key)
|
||||
cm2 = yield d2
|
||||
|
||||
d3 = limiter.queue(key)
|
||||
cm3 = yield d3
|
||||
|
||||
d4 = limiter.queue(key)
|
||||
self.assertFalse(d4.called)
|
||||
|
||||
d5 = limiter.queue(key)
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
with cm1:
|
||||
self.assertFalse(d4.called)
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
cm4 = yield d4
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
with cm3:
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
cm5 = yield d5
|
||||
|
||||
with cm2:
|
||||
pass
|
||||
|
||||
with cm4:
|
||||
pass
|
||||
|
||||
with cm5:
|
||||
pass
|
||||
|
||||
d6 = limiter.queue(key)
|
||||
with (yield d6):
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_cancellation(self):
|
||||
linearizer = Linearizer()
|
||||
|
||||
key = object()
|
||||
|
||||
d1 = linearizer.queue(key)
|
||||
cm1 = yield d1
|
||||
|
||||
d2 = linearizer.queue(key)
|
||||
self.assertFalse(d2.called)
|
||||
|
||||
d3 = linearizer.queue(key)
|
||||
self.assertFalse(d3.called)
|
||||
|
||||
d2.cancel()
|
||||
|
||||
with cm1:
|
||||
pass
|
||||
|
||||
self.assertTrue(d2.called)
|
||||
try:
|
||||
yield d2
|
||||
self.fail("Expected d2 to raise CancelledError")
|
||||
except CancelledError:
|
||||
pass
|
||||
|
||||
with (yield d3):
|
||||
pass
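The two new tests cover the extremes of the Linearizer API shown here: queue(key) returns a Deferred that fires with a context manager once the caller may enter, at most max_count callers hold it per key at a time, and a still-queued Deferred can be cancelled to give up its place (failing with CancelledError) without blocking later waiters. A minimal usage sketch, not a new test, assuming the imports already at the top of this module; _do_exclusive_work is a hypothetical name:

linearizer = Linearizer(max_count=1)   # at most one holder per key

@defer.inlineCallbacks
def _do_exclusive_work(key):
    # queue() returns a Deferred firing with a context manager once we may enter
    with (yield linearizer.queue(key)):
        pass  # work in here never runs concurrently for the same key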
|
||||
|
|
|
|||
|
|
@ -141,8 +141,8 @@ class StreamChangeCacheTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
# Query all the entries mid-way through the stream, but include one
|
||||
# that doesn't exist in it. We should get back the one that doesn't
|
||||
# exist, too.
|
||||
# that doesn't exist in it. We shouldn't get back the one that doesn't
|
||||
# exist.
|
||||
self.assertEqual(
|
||||
cache.get_entities_changed(
|
||||
[
|
||||
|
|
@ -153,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase):
|
|||
],
|
||||
stream_pos=2,
|
||||
),
|
||||
set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]),
|
||||
set(["bar@baz.net", "user@elsewhere.org"]),
|
||||
)
|
||||
|
||||
# Query all the entries, but before the first known point. We will get
|
||||
|
|
@ -178,6 +178,22 @@ class StreamChangeCacheTests(unittest.TestCase):
|
|||
),
|
||||
)
|
||||
|
||||
# Query a subset of the entries mid-way through the stream. We should
|
||||
# only get back the subset.
|
||||
self.assertEqual(
|
||||
cache.get_entities_changed(
|
||||
[
|
||||
"bar@baz.net",
|
||||
],
|
||||
stream_pos=2,
|
||||
),
|
||||
set(
|
||||
[
|
||||
"bar@baz.net",
|
||||
]
|
||||
),
|
||||
)
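The two changes to this file pin down the same behaviour: get_entities_changed now only returns entities that were actually passed in and that the cache knows have changed since the given position; names the cache has never seen are dropped rather than conservatively returned. A rough pure-Python sketch of that rule with made-up positions, illustrative only and not the cache's implementation:

def entities_changed_since(last_changed, entities, stream_pos):
    # last_changed: entity -> stream position at which it last changed
    return set(
        e for e in entities
        if e in last_changed and last_changed[e] > stream_pos
    )

positions = {"user@foo.com": 2, "bar@baz.net": 3, "user@elsewhere.org": 4}
assert entities_changed_since(
    positions, ["bar@baz.net", "not@here.website"], stream_pos=2,
) == set(["bar@baz.net"])   # the unknown entity is no longer returned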
|
||||
|
||||
def test_max_pos(self):
|
||||
"""
|
||||
StreamChangeCache.get_max_pos_of_last_change will return the most
|
||||
|
|
|
|||
|
|
@ -71,6 +71,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
|
|||
config.user_directory_search_all_users = False
|
||||
config.user_consent_server_notice_content = None
|
||||
config.block_events_without_consent_error = None
|
||||
config.media_storage_providers = []
|
||||
config.auto_join_rooms = []
|
||||
|
||||
# disable user directory updates, because they get done in the
|
||||
# background, which upsets the test runner.
|
||||
|
|
@ -138,6 +140,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
|
|||
room_list_handler=object(),
|
||||
tls_server_context_factory=Mock(),
|
||||
tls_client_options_factory=Mock(),
|
||||
reactor=reactor,
|
||||
**kargs
|
||||
)
|
||||
|
||||
|
|
|
|||