Convert additional test-cases to homeserver test case. (#9396)

And convert some inlineDeferreds to async-friendly functions.
This commit is contained in:
Patrick Cloke 2021-02-16 08:04:15 -05:00 committed by GitHub
parent ff40c8099d
commit 74af356baf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 214 additions and 331 deletions

1
changelog.d/9396.misc Normal file
View File

@ -0,0 +1 @@
Convert tests to use `HomeserverTestCase`.

View File

@ -17,8 +17,6 @@ from mock import Mock
import pymacaroons
from twisted.internet import defer
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import (
@ -33,19 +31,17 @@ from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID
from tests import unittest
from tests.utils import mock_getRawHeaders, setup_test_homeserver
from tests.test_utils import simple_async_mock
from tests.utils import mock_getRawHeaders
class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.state_handler = Mock()
class AuthTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = Mock()
self.hs = yield setup_test_homeserver(self.addCleanup)
self.hs.get_datastore = Mock(return_value=self.store)
self.hs.get_auth_handler().store = self.store
self.auth = Auth(self.hs)
hs.get_datastore = Mock(return_value=self.store)
hs.get_auth_handler().store = self.store
self.auth = Auth(hs)
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
@ -57,64 +53,59 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
self.store.is_support_user = Mock(return_value=defer.succeed(False))
self.store.insert_client_ip = simple_async_mock(None)
self.store.is_support_user = simple_async_mock(False)
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
self.store.get_user_by_access_token = simple_async_mock(user_info)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
f = self.get_failure(
self.auth.get_user_by_req(request), InvalidClientTokenError
).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
self.store.get_user_by_access_token = simple_async_mock(user_info)
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
f = self.get_failure(
self.auth.get_user_by_req(request), MissingClientTokenError
).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self):
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_good_ip(self):
from netaddr import IPSet
@ -125,13 +116,13 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@ -144,42 +135,44 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
f = self.get_failure(
self.auth.get_user_by_req(request), InvalidClientTokenError
).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
f = self.get_failure(
self.auth.get_user_by_req(request), InvalidClientTokenError
).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
f = self.get_failure(
self.auth.get_user_by_req(request), MissingClientTokenError
).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock(
@ -188,17 +181,15 @@ class AuthTestCase(unittest.TestCase):
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = Mock(
return_value=defer.succeed({"is_guest": False})
)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(
requester.user.to_string(), masquerading_user_id.decode("utf8")
)
@ -210,23 +201,19 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
self.failureResultOf(d, AuthError)
self.get_failure(self.auth.get_user_by_req(request), AuthError)
@defer.inlineCallbacks
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
)
)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@ -237,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
user_info = yield defer.ensureDeferred(
user_info = self.get_success(
self.auth.get_user_by_access_token(macaroon.serialize())
)
self.assertEqual(user_id, user_info.user_id)
@ -246,10 +233,9 @@ class AuthTestCase(unittest.TestCase):
# from the db.
self.assertEqual(user_info.device_id, "device")
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
self.store.get_user_by_id = simple_async_mock({"is_guest": True})
self.store.get_user_by_access_token = simple_async_mock(None)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@ -263,20 +249,17 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized)
)
user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
self.assertEqual(user_id, user_info.user_id)
self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org"
self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
self.store.get_device = Mock(return_value=defer.succeed(None))
self.store.add_access_token_to_user = simple_async_mock(None)
self.store.get_device = simple_async_mock(None)
token = yield defer.ensureDeferred(
token = self.get_success(
self.hs.get_auth_handler().get_access_token_for_user_id(
USER_ID, "DEVICE", valid_until_ms=None
)
@ -289,25 +272,21 @@ class AuthTestCase(unittest.TestCase):
puppets_user_id=None,
)
def get_user(tok):
async def get_user(tok):
if token != tok:
return defer.succeed(None)
return defer.succeed(
TokenLookupResult(
return None
return TokenLookupResult(
user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
)
)
self.store.get_user_by_access_token = get_user
self.store.get_user_by_id = Mock(
return_value=defer.succeed({"is_guest": False})
)
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
# check the token works
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield defer.ensureDeferred(
requester = self.get_success(
self.auth.get_user_by_req(request, allow_guest=True)
)
self.assertEqual(UserID.from_string(USER_ID), requester.user)
@ -323,17 +302,16 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [guest_tok.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(InvalidClientCredentialsError) as cm:
yield defer.ensureDeferred(
self.auth.get_user_by_req(request, allow_guest=True)
cm = self.get_failure(
self.auth.get_user_by_req(request, allow_guest=True),
InvalidClientCredentialsError,
)
self.assertEqual(401, cm.exception.code)
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
self.assertEqual(401, cm.value.code)
self.assertEqual("Guest access token used for regular user", cm.value.msg)
self.store.get_user_by_id.assert_called_with(USER_ID)
@defer.inlineCallbacks
def test_blocking_mau(self):
self.auth_blocking._limit_usage_by_mau = False
self.auth_blocking._max_mau_value = 50
@ -341,77 +319,61 @@ class AuthTestCase(unittest.TestCase):
small_number_of_users = 1
# Ensure no error thrown
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.get_success(self.auth.check_auth_blocking())
self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users)
)
self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
with self.assertRaises(ResourceLimitError) as e:
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.value.code, 403)
# Ensure does not throw an error
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(small_number_of_users)
)
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
self.get_success(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self):
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
self.store.get_monthly_active_count = simple_async_mock(100)
# Support users allowed
yield defer.ensureDeferred(
self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
)
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
self.store.get_monthly_active_count = simple_async_mock(100)
# Bots not allowed
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
self.auth.check_auth_blocking(user_type=UserTypes.BOT)
self.get_failure(
self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
)
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
self.store.get_monthly_active_count = simple_async_mock(100)
# Real users not allowed
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
@defer.inlineCallbacks
def test_reserved_threepid(self):
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = lambda: defer.succeed(2)
self.store.get_monthly_active_count = simple_async_mock(2)
threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
self.auth.check_auth_blocking(threepid=unknown_threepid)
self.get_failure(
self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
)
yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
self.get_success(self.auth.check_auth_blocking(threepid=threepid))
@defer.inlineCallbacks
def test_hs_disabled(self):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.value.code, 403)
@defer.inlineCallbacks
def test_hs_disabled_no_server_notices_user(self):
"""Check that 'hs_disabled_message' works correctly when there is no
server_notices user.
@ -422,16 +384,14 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.value.code, 403)
@defer.inlineCallbacks
def test_server_notices_mxid_special_cased(self):
self.auth_blocking._hs_disabled = True
user = "@user:server"
self.auth_blocking._server_notices_mxid = user
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
self.get_success(self.auth.check_auth_blocking(user))

View File

@ -18,15 +18,12 @@
import jsonschema
from twisted.internet import defer
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict
from tests import unittest
from tests.utils import setup_test_homeserver
user_localpart = "test_user"
@ -39,9 +36,8 @@ def MockEvent(**kwargs):
return make_event_from_dict(kwargs)
class FilteringTestCase(unittest.TestCase):
def setUp(self):
hs = setup_test_homeserver(self.addCleanup)
class FilteringTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.filtering = hs.get_filtering()
self.datastore = hs.get_datastore()
@ -351,10 +347,9 @@ class FilteringTestCase(unittest.TestCase):
self.assertTrue(Filter(definition).check(event))
@defer.inlineCallbacks
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
filter_id = yield defer.ensureDeferred(
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@ -362,7 +357,7 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event]
user_filter = yield defer.ensureDeferred(
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
@ -371,11 +366,10 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_presence(events=events)
self.assertEquals(events, results)
@defer.inlineCallbacks
def test_filter_presence_no_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
filter_id = yield defer.ensureDeferred(
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_filter=user_filter_json
)
@ -387,7 +381,7 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
user_filter = yield defer.ensureDeferred(
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart + "2", filter_id=filter_id
)
@ -396,10 +390,9 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_presence(events=events)
self.assertEquals([], results)
@defer.inlineCallbacks
def test_filter_room_state_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield defer.ensureDeferred(
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@ -407,7 +400,7 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event]
user_filter = yield defer.ensureDeferred(
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
@ -416,10 +409,9 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_room_state(events=events)
self.assertEquals(events, results)
@defer.inlineCallbacks
def test_filter_room_state_no_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield defer.ensureDeferred(
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@ -429,7 +421,7 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
user_filter = yield defer.ensureDeferred(
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
@ -454,11 +446,10 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
@defer.inlineCallbacks
def test_add_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield defer.ensureDeferred(
filter_id = self.get_success(
self.filtering.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@ -468,7 +459,7 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(
user_filter_json,
(
yield defer.ensureDeferred(
self.get_success(
self.datastore.get_user_filter(
user_localpart=user_localpart, filter_id=0
)
@ -476,17 +467,16 @@ class FilteringTestCase(unittest.TestCase):
),
)
@defer.inlineCallbacks
def test_get_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield defer.ensureDeferred(
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
)
filter = yield defer.ensureDeferred(
filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)

View File

@ -35,8 +35,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_scheduler = Mock()
hs = Mock()
hs.get_datastore.return_value = self.mock_store
self.mock_store.get_received_ts.return_value = defer.succeed(0)
self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None)
self.mock_store.get_received_ts.return_value = make_awaitable(0)
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
@ -50,16 +50,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice(is_interested=False),
]
self.mock_as_api.query_user.return_value = defer.succeed(True)
self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = defer.succeed([])
self.mock_store.get_user_by_id.return_value = make_awaitable([])
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
self.mock_store.get_new_events_for_appservice.side_effect = [
defer.succeed((0, [event])),
defer.succeed((0, [])),
make_awaitable((0, [event])),
make_awaitable((0, [])),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@ -72,13 +72,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = defer.succeed(None)
self.mock_store.get_user_by_id.return_value = make_awaitable(None)
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = defer.succeed(True)
self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_new_events_for_appservice.side_effect = [
defer.succeed((0, [event])),
defer.succeed((0, [])),
make_awaitable((0, [event])),
make_awaitable((0, [])),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@ -90,13 +90,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id})
self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = defer.succeed(True)
self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_new_events_for_appservice.side_effect = [
defer.succeed((0, [event])),
defer.succeed((0, [])),
make_awaitable((0, [event])),
make_awaitable((0, [])),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@ -106,7 +106,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
"query_user called when it shouldn't have been.",
)
@defer.inlineCallbacks
def test_query_room_alias_exists(self):
room_alias_str = "#foo:bar"
room_alias = Mock()
@ -127,8 +126,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
Mock(room_id=room_id, servers=servers)
)
result = yield defer.ensureDeferred(
self.handler.query_room_alias_exists(room_alias)
result = self.successResultOf(
defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
)
self.mock_as_api.query_alias.assert_called_once_with(

View File

@ -14,163 +14,11 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
import json
from mock import Mock
from twisted.internet import defer
import synapse.types
from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
from synapse.rest.client.v1 import login, profile, room
from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test"
PATH_PREFIX = "/_matrix/client/r0"
class MockHandlerProfileTestCase(unittest.TestCase):
""" Tests rest layer of profile management.
Todo: move these into ProfileTestCase
"""
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.mock_handler = Mock(
spec=[
"get_displayname",
"set_displayname",
"get_avatar_url",
"set_avatar_url",
"check_profile_query_allowed",
]
)
self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
Mock()
)
hs = yield setup_test_homeserver(
self.addCleanup,
"test",
federation_http_client=None,
resource_for_client=self.mock_resource,
federation=Mock(),
federation_client=Mock(),
profile_handler=self.mock_handler,
)
async def _get_user_by_req(request=None, allow_guest=False):
return synapse.types.create_requester(myid)
hs.get_auth().get_user_by_req = _get_user_by_req
profile.register_servlets(hs, self.mock_resource)
@defer.inlineCallbacks
def test_get_my_name(self):
mocked_get = self.mock_handler.get_displayname
mocked_get.return_value = defer.succeed("Frank")
(code, response) = yield self.mock_resource.trigger(
"GET", "/profile/%s/displayname" % (myid), None
)
self.assertEquals(200, code)
self.assertEquals({"displayname": "Frank"}, response)
self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
@defer.inlineCallbacks
def test_set_my_name(self):
mocked_set = self.mock_handler.set_displayname
mocked_set.return_value = defer.succeed(())
(code, response) = yield self.mock_resource.trigger(
"PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}'
)
self.assertEquals(200, code)
self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.")
@defer.inlineCallbacks
def test_set_my_name_noauth(self):
mocked_set = self.mock_handler.set_displayname
mocked_set.side_effect = AuthError(400, "message")
(code, response) = yield self.mock_resource.trigger(
"PUT",
"/profile/%s/displayname" % ("@4567:test"),
b'{"displayname": "Frank Jr."}',
)
self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code))
@defer.inlineCallbacks
def test_get_other_name(self):
mocked_get = self.mock_handler.get_displayname
mocked_get.return_value = defer.succeed("Bob")
(code, response) = yield self.mock_resource.trigger(
"GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None
)
self.assertEquals(200, code)
self.assertEquals({"displayname": "Bob"}, response)
@defer.inlineCallbacks
def test_set_other_name(self):
mocked_set = self.mock_handler.set_displayname
mocked_set.side_effect = SynapseError(400, "message")
(code, response) = yield self.mock_resource.trigger(
"PUT",
"/profile/%s/displayname" % ("@opaque:elsewhere"),
b'{"displayname":"bob"}',
)
self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code))
@defer.inlineCallbacks
def test_get_my_avatar(self):
mocked_get = self.mock_handler.get_avatar_url
mocked_get.return_value = defer.succeed("http://my.server/me.png")
(code, response) = yield self.mock_resource.trigger(
"GET", "/profile/%s/avatar_url" % (myid), None
)
self.assertEquals(200, code)
self.assertEquals({"avatar_url": "http://my.server/me.png"}, response)
self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
@defer.inlineCallbacks
def test_set_my_avatar(self):
mocked_set = self.mock_handler.set_avatar_url
mocked_set.return_value = defer.succeed(())
(code, response) = yield self.mock_resource.trigger(
"PUT",
"/profile/%s/avatar_url" % (myid),
b'{"avatar_url": "http://my.server/pic.gif"}',
)
self.assertEquals(200, code)
self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
class ProfileTestCase(unittest.HomeserverTestCase):
@ -187,37 +35,122 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass")
self.other = self.register_user("other", "pass", displayname="Bob")
def test_get_displayname(self):
res = self._get_displayname()
self.assertEqual(res, "owner")
def test_set_displayname(self):
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content=json.dumps({"displayname": "test"}),
content={"displayname": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
res = self.get_displayname()
res = self._get_displayname()
self.assertEqual(res, "test")
def test_set_displayname_noauth(self):
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content={"displayname": "test"},
)
self.assertEqual(channel.code, 401, channel.result)
def test_set_displayname_too_long(self):
"""Attempts to set a stupid displayname should get a 400"""
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content=json.dumps({"displayname": "test" * 100}),
content={"displayname": "test" * 100},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
res = self.get_displayname()
res = self._get_displayname()
self.assertEqual(res, "owner")
def get_displayname(self):
channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,))
def test_get_displayname_other(self):
res = self._get_displayname(self.other)
self.assertEquals(res, "Bob")
def test_set_displayname_other(self):
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.other,),
content={"displayname": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
def test_get_avatar_url(self):
res = self._get_avatar_url()
self.assertIsNone(res)
def test_set_avatar_url(self):
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.owner,),
content={"avatar_url": "http://my.server/pic.gif"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
res = self._get_avatar_url()
self.assertEqual(res, "http://my.server/pic.gif")
def test_set_avatar_url_noauth(self):
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.owner,),
content={"avatar_url": "http://my.server/pic.gif"},
)
self.assertEqual(channel.code, 401, channel.result)
def test_set_avatar_url_too_long(self):
"""Attempts to set a stupid avatar_url should get a 400"""
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.owner,),
content={"avatar_url": "http://my.server/pic.gif" * 100},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
res = self._get_avatar_url()
self.assertIsNone(res)
def test_get_avatar_url_other(self):
res = self._get_avatar_url(self.other)
self.assertIsNone(res)
def test_set_avatar_url_other(self):
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.other,),
content={"avatar_url": "http://my.server/pic.gif"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
def _get_displayname(self, name=None):
channel = self.make_request(
"GET", "/profile/%s/displayname" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["displayname"]
def _get_avatar_url(self, name=None):
channel = self.make_request(
"GET", "/profile/%s/avatar_url" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body.get("avatar_url")
class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):