Merge branch 'develop' into matthew/filter_members

This commit is contained in:
Matthew Hodgson 2018-07-23 19:21:37 +01:00
commit adfe29ec0b
53 changed files with 3683 additions and 3243 deletions

View file

@ -222,9 +222,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_ids = {
key: e.event_id for key, e in state.items()
}
context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
context = EventContext.with_state(
state_group=None,
current_state_ids=state_ids,
prev_state_ids=state_ids
)
else:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)

View file

@ -0,0 +1,305 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import hmac
import json
from mock import Mock
from synapse.http.server import JsonResource
from synapse.rest.client.v1.admin import register_servlets
from synapse.util import Clock
from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)
class UserRegisterTestCase(unittest.TestCase):
def setUp(self):
self.clock = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.url = "/_matrix/client/r0/admin/register"
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
self.device_handler = Mock()
self.device_handler.check_device_registered = Mock(return_value="FAKE")
self.datastore = Mock(return_value=Mock())
self.datastore.get_current_state_deltas = Mock(return_value=[])
self.secrets = Mock()
self.hs = setup_test_homeserver(
http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.hs.config.registration_shared_secret = u"shared"
self.hs.get_media_repository = Mock()
self.hs.get_deactivate_account_handler = Mock()
self.resource = JsonResource(self.hs)
register_servlets(self.hs, self.resource)
def test_disabled(self):
"""
If there is no shared secret, registration through this method will be
prevented.
"""
self.hs.config.registration_shared_secret = None
request, channel = make_request("POST", self.url, b'{}')
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
'Shared secret registration is not enabled', channel.json_body["error"]
)
def test_get_nonce(self):
"""
Calling GET on the endpoint will return a randomised nonce, using the
homeserver's secrets provider.
"""
secrets = Mock()
secrets.token_hex = Mock(return_value="abcd")
self.hs.get_secrets = Mock(return_value=secrets)
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
self.assertEqual(channel.json_body, {"nonce": "abcd"})
def test_expired_nonce(self):
"""
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
# 59 seconds
self.clock.advance(59)
body = json.dumps({"nonce": nonce})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# 61 seconds
self.clock.advance(2)
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
def test_register_incorrect_nonce(self):
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
body = json.dumps(
{
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
"mac": want_mac,
}
).encode('utf8')
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
def test_register_correct_nonce(self):
"""
When the correct nonce is provided, and the right key is provided, the
user is registered.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
body = json.dumps(
{
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
"mac": want_mac,
}
).encode('utf8')
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
def test_nonce_reuse(self):
"""
A valid unrecognised nonce.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
body = json.dumps(
{
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
"mac": want_mac,
}
).encode('utf8')
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
def test_missing_parts(self):
"""
Synapse will complain if you don't give nonce, username, password, and
mac. Admin is optional. Additional checks are done for length and
type.
"""
def nonce():
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
return channel.json_body["nonce"]
#
# Nonce check
#
# Must be present
body = json.dumps({})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('nonce must be specified', channel.json_body["error"])
#
# Username checks
#
# Must be present
body = json.dumps({"nonce": nonce()})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
#
# Username checks
#
# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('password must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])

View file

@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase):
)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
@unittest.DEBUG
def test_cant_hide_past_history(self):
"""
If you send a message, you must be able to provide the direct
@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase):
for x, y in d.items()
if x == ("m.room.member", "@us:test")
],
"auth_chain_ids": d.values(),
"auth_chain_ids": list(d.values()),
}
)

View file

@ -204,7 +204,8 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].prev_state_ids))
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertEqual(2, len(prev_state_ids))
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
@ -255,9 +256,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "C"},
{e_id for e_id in context_store["D"].prev_state_ids.values()}
{e_id for e_id in prev_state_ids.values()}
)
@defer.inlineCallbacks
@ -318,9 +321,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "B", "C"},
{e for e in context_store["E"].prev_state_ids.values()}
{e for e in prev_state_ids.values()}
)
@defer.inlineCallbacks
@ -398,9 +403,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"},
{e for e in context_store["D"].prev_state_ids.values()}
{e for e in prev_state_ids.values()}
)
def _add_depths(self, nodes, edges):
@ -429,8 +436,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(
set(e.event_id for e in old_state), set(context.current_state_ids.values())
set(e.event_id for e in old_state), set(current_state_ids.values())
)
self.assertIsNotNone(context.state_group)
@ -449,8 +458,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertEqual(
set(e.event_id for e in old_state), set(context.prev_state_ids.values())
set(e.event_id for e in old_state), set(prev_state_ids.values())
)
@defer.inlineCallbacks
@ -475,9 +486,11 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(
set([e.event_id for e in old_state]),
set(context.current_state_ids.values())
set(current_state_ids.values())
)
self.assertEqual(group_name, context.state_group)
@ -504,9 +517,11 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertEqual(
set([e.event_id for e in old_state]),
set(context.prev_state_ids.values())
set(prev_state_ids.values())
)
self.assertIsNotNone(context.state_group)
@ -545,7 +560,9 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
self.assertEqual(len(context.current_state_ids), 6)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(len(current_state_ids), 6)
self.assertIsNotNone(context.state_group)
@ -585,7 +602,9 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
self.assertEqual(len(context.current_state_ids), 6)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(len(current_state_ids), 6)
self.assertIsNotNone(context.state_group)
@ -642,8 +661,10 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(
old_state_2[3].event_id, context.current_state_ids[("test1", "1")]
old_state_2[3].event_id, current_state_ids[("test1", "1")]
)
# Reverse the depth to make sure we are actually using the depths
@ -670,8 +691,10 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(
old_state_1[3].event_id, context.current_state_ids[("test1", "1")]
old_state_1[3].event_id, current_state_ids[("test1", "1")]
)
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,

View file

@ -1,70 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.util.async import Limiter
from tests import unittest
class LimiterTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_limiter(self):
limiter = Limiter(3)
key = object()
d1 = limiter.queue(key)
cm1 = yield d1
d2 = limiter.queue(key)
cm2 = yield d2
d3 = limiter.queue(key)
cm3 = yield d3
d4 = limiter.queue(key)
self.assertFalse(d4.called)
d5 = limiter.queue(key)
self.assertFalse(d5.called)
with cm1:
self.assertFalse(d4.called)
self.assertFalse(d5.called)
self.assertTrue(d4.called)
self.assertFalse(d5.called)
with cm3:
self.assertFalse(d5.called)
self.assertTrue(d5.called)
with cm2:
pass
with (yield d4):
pass
with (yield d5):
pass
d6 = limiter.queue(key)
with (yield d6):
pass

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,6 +17,7 @@
from six.moves import range
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
from synapse.util import Clock, logcontext
from synapse.util.async import Linearizer
@ -65,3 +67,79 @@ class LinearizerTestCase(unittest.TestCase):
func(i)
return func(1000)
@defer.inlineCallbacks
def test_multiple_entries(self):
limiter = Linearizer(max_count=3)
key = object()
d1 = limiter.queue(key)
cm1 = yield d1
d2 = limiter.queue(key)
cm2 = yield d2
d3 = limiter.queue(key)
cm3 = yield d3
d4 = limiter.queue(key)
self.assertFalse(d4.called)
d5 = limiter.queue(key)
self.assertFalse(d5.called)
with cm1:
self.assertFalse(d4.called)
self.assertFalse(d5.called)
cm4 = yield d4
self.assertFalse(d5.called)
with cm3:
self.assertFalse(d5.called)
cm5 = yield d5
with cm2:
pass
with cm4:
pass
with cm5:
pass
d6 = limiter.queue(key)
with (yield d6):
pass
@defer.inlineCallbacks
def test_cancellation(self):
linearizer = Linearizer()
key = object()
d1 = linearizer.queue(key)
cm1 = yield d1
d2 = linearizer.queue(key)
self.assertFalse(d2.called)
d3 = linearizer.queue(key)
self.assertFalse(d3.called)
d2.cancel()
with cm1:
pass
self.assertTrue(d2.called)
try:
yield d2
self.fail("Expected d2 to raise CancelledError")
except CancelledError:
pass
with (yield d3):
pass

View file

@ -71,6 +71,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
config.user_directory_search_all_users = False
config.user_consent_server_notice_content = None
config.block_events_without_consent_error = None
config.media_storage_providers = []
config.auto_join_rooms = []
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
@ -136,6 +138,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
reactor=reactor,
**kargs
)