Convert additional test-cases to homeserver test case. (#9396)

And convert some inlineDeferreds to async-friendly functions.
This commit is contained in:
Patrick Cloke 2021-02-16 08:04:15 -05:00 committed by GitHub
parent ff40c8099d
commit 74af356baf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 214 additions and 331 deletions

1
changelog.d/9396.misc Normal file
View File

@ -0,0 +1 @@
Convert tests to use `HomeserverTestCase`.

View File

@ -17,8 +17,6 @@ from mock import Mock
import pymacaroons import pymacaroons
from twisted.internet import defer
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import ( from synapse.api.errors import (
@ -33,19 +31,17 @@ from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
from tests.utils import mock_getRawHeaders, setup_test_homeserver from tests.test_utils import simple_async_mock
from tests.utils import mock_getRawHeaders
class AuthTestCase(unittest.TestCase): class AuthTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
self.state_handler = Mock()
self.store = Mock() self.store = Mock()
self.hs = yield setup_test_homeserver(self.addCleanup) hs.get_datastore = Mock(return_value=self.store)
self.hs.get_datastore = Mock(return_value=self.store) hs.get_auth_handler().store = self.store
self.hs.get_auth_handler().store = self.store self.auth = Auth(hs)
self.auth = Auth(self.hs)
# AuthBlocking reads from the hs' config on initialization. We need to # AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs' # modify its config instead of the hs'
@ -57,64 +53,59 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests # this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
self.store.insert_client_ip = Mock(return_value=defer.succeed(None)) self.store.insert_client_ip = simple_async_mock(None)
self.store.is_support_user = Mock(return_value=defer.succeed(False)) self.store.is_support_user = simple_async_mock(False)
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self): def test_get_user_by_req_user_valid_token(self):
user_info = TokenLookupResult( user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device" user_id=self.test_user, token_id=5, device_id="device"
) )
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = simple_async_mock(user_info)
return_value=defer.succeed(user_info)
)
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self): def test_get_user_by_req_user_bad_token(self):
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.get_failure(
f = self.failureResultOf(d, InvalidClientTokenError).value self.auth.get_user_by_req(request), InvalidClientTokenError
).value
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self): def test_get_user_by_req_user_missing_token(self):
user_info = TokenLookupResult(user_id=self.test_user, token_id=5) user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = simple_async_mock(user_info)
return_value=defer.succeed(user_info)
)
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.get_failure(
f = self.failureResultOf(d, MissingClientTokenError).value self.auth.get_user_by_req(request), MissingClientTokenError
).value
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN") self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self): def test_get_user_by_req_appservice_valid_token(self):
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_good_ip(self): def test_get_user_by_req_appservice_valid_token_good_ip(self):
from netaddr import IPSet from netaddr import IPSet
@ -125,13 +116,13 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]), ip_range_whitelist=IPSet(["192.168/16"]),
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10" request.getClientIP.return_value = "192.168.10.10"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self): def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@ -144,42 +135,44 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]), ip_range_whitelist=IPSet(["192.168/16"]),
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42" request.getClientIP.return_value = "131.111.8.42"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.get_failure(
f = self.failureResultOf(d, InvalidClientTokenError).value self.auth.get_user_by_req(request), InvalidClientTokenError
).value
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self): def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.get_failure(
f = self.failureResultOf(d, InvalidClientTokenError).value self.auth.get_user_by_req(request), InvalidClientTokenError
).value
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_missing_token(self): def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.get_failure(
f = self.failureResultOf(d, MissingClientTokenError).value self.auth.get_user_by_req(request), MissingClientTokenError
).value
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN") self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self): def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = b"@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
@ -188,17 +181,15 @@ class AuthTestCase(unittest.TestCase):
app_service.is_interested_in_user = Mock(return_value=True) app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value. # This just needs to return a truth-y value.
self.store.get_user_by_id = Mock( self.store.get_user_by_id = simple_async_mock({"is_guest": False})
return_value=defer.succeed({"is_guest": False}) self.store.get_user_by_access_token = simple_async_mock(None)
)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals( self.assertEquals(
requester.user.to_string(), masquerading_user_id.decode("utf8") requester.user.to_string(), masquerading_user_id.decode("utf8")
) )
@ -210,22 +201,18 @@ class AuthTestCase(unittest.TestCase):
) )
app_service.is_interested_in_user = Mock(return_value=False) app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = defer.ensureDeferred(self.auth.get_user_by_req(request)) self.get_failure(self.auth.get_user_by_req(request), AuthError)
self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_from_macaroon(self): def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = simple_async_mock(
return_value=defer.succeed( TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
)
) )
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
@ -237,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
user_info = yield defer.ensureDeferred( user_info = self.get_success(
self.auth.get_user_by_access_token(macaroon.serialize()) self.auth.get_user_by_access_token(macaroon.serialize())
) )
self.assertEqual(user_id, user_info.user_id) self.assertEqual(user_id, user_info.user_id)
@ -246,10 +233,9 @@ class AuthTestCase(unittest.TestCase):
# from the db. # from the db.
self.assertEqual(user_info.device_id, "device") self.assertEqual(user_info.device_id, "device")
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self): def test_get_guest_user_from_macaroon(self):
self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True})) self.store.get_user_by_id = simple_async_mock({"is_guest": True})
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) self.store.get_user_by_access_token = simple_async_mock(None)
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
@ -263,20 +249,17 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true") macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize() serialized = macaroon.serialize()
user_info = yield defer.ensureDeferred( user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
self.auth.get_user_by_access_token(serialized)
)
self.assertEqual(user_id, user_info.user_id) self.assertEqual(user_id, user_info.user_id)
self.assertTrue(user_info.is_guest) self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id) self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self): def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org" USER_ID = "@percy:matrix.org"
self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None)) self.store.add_access_token_to_user = simple_async_mock(None)
self.store.get_device = Mock(return_value=defer.succeed(None)) self.store.get_device = simple_async_mock(None)
token = yield defer.ensureDeferred( token = self.get_success(
self.hs.get_auth_handler().get_access_token_for_user_id( self.hs.get_auth_handler().get_access_token_for_user_id(
USER_ID, "DEVICE", valid_until_ms=None USER_ID, "DEVICE", valid_until_ms=None
) )
@ -289,25 +272,21 @@ class AuthTestCase(unittest.TestCase):
puppets_user_id=None, puppets_user_id=None,
) )
def get_user(tok): async def get_user(tok):
if token != tok: if token != tok:
return defer.succeed(None) return None
return defer.succeed( return TokenLookupResult(
TokenLookupResult( user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
)
) )
self.store.get_user_by_access_token = get_user self.store.get_user_by_access_token = get_user
self.store.get_user_by_id = Mock( self.store.get_user_by_id = simple_async_mock({"is_guest": False})
return_value=defer.succeed({"is_guest": False})
)
# check the token works # check the token works
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")] request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield defer.ensureDeferred( requester = self.get_success(
self.auth.get_user_by_req(request, allow_guest=True) self.auth.get_user_by_req(request, allow_guest=True)
) )
self.assertEqual(UserID.from_string(USER_ID), requester.user) self.assertEqual(UserID.from_string(USER_ID), requester.user)
@ -323,17 +302,16 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [guest_tok.encode("ascii")] request.args[b"access_token"] = [guest_tok.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(InvalidClientCredentialsError) as cm: cm = self.get_failure(
yield defer.ensureDeferred( self.auth.get_user_by_req(request, allow_guest=True),
self.auth.get_user_by_req(request, allow_guest=True) InvalidClientCredentialsError,
) )
self.assertEqual(401, cm.exception.code) self.assertEqual(401, cm.value.code)
self.assertEqual("Guest access token used for regular user", cm.exception.msg) self.assertEqual("Guest access token used for regular user", cm.value.msg)
self.store.get_user_by_id.assert_called_with(USER_ID) self.store.get_user_by_id.assert_called_with(USER_ID)
@defer.inlineCallbacks
def test_blocking_mau(self): def test_blocking_mau(self):
self.auth_blocking._limit_usage_by_mau = False self.auth_blocking._limit_usage_by_mau = False
self.auth_blocking._max_mau_value = 50 self.auth_blocking._max_mau_value = 50
@ -341,77 +319,61 @@ class AuthTestCase(unittest.TestCase):
small_number_of_users = 1 small_number_of_users = 1
# Ensure no error thrown # Ensure no error thrown
yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.get_success(self.auth.check_auth_blocking())
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
return_value=defer.succeed(lots_of_users)
)
with self.assertRaises(ResourceLimitError) as e: e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403)
self.assertEquals(e.exception.code, 403)
# Ensure does not throw an error # Ensure does not throw an error
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
return_value=defer.succeed(small_number_of_users) self.get_success(self.auth.check_auth_blocking())
)
yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self): def test_blocking_mau__depending_on_user_type(self):
self.auth_blocking._max_mau_value = 50 self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) self.store.get_monthly_active_count = simple_async_mock(100)
# Support users allowed # Support users allowed
yield defer.ensureDeferred( self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) self.store.get_monthly_active_count = simple_async_mock(100)
)
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Bots not allowed # Bots not allowed
with self.assertRaises(ResourceLimitError): self.get_failure(
yield defer.ensureDeferred( self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
self.auth.check_auth_blocking(user_type=UserTypes.BOT) )
) self.store.get_monthly_active_count = simple_async_mock(100)
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Real users not allowed # Real users not allowed
with self.assertRaises(ResourceLimitError): self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_reserved_threepid(self): def test_reserved_threepid(self):
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1 self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = lambda: defer.succeed(2) self.store.get_monthly_active_count = simple_async_mock(2)
threepid = {"medium": "email", "address": "reserved@server.com"} threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.auth_blocking._mau_limits_reserved_threepids = [threepid] self.auth_blocking._mau_limits_reserved_threepids = [threepid]
with self.assertRaises(ResourceLimitError): self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
yield defer.ensureDeferred(self.auth.check_auth_blocking())
with self.assertRaises(ResourceLimitError): self.get_failure(
yield defer.ensureDeferred( self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
self.auth.check_auth_blocking(threepid=unknown_threepid) )
)
yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid)) self.get_success(self.auth.check_auth_blocking(threepid=threepid))
@defer.inlineCallbacks
def test_hs_disabled(self): def test_hs_disabled(self):
self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled" self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e: e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403)
self.assertEquals(e.exception.code, 403)
@defer.inlineCallbacks
def test_hs_disabled_no_server_notices_user(self): def test_hs_disabled_no_server_notices_user(self):
"""Check that 'hs_disabled_message' works correctly when there is no """Check that 'hs_disabled_message' works correctly when there is no
server_notices user. server_notices user.
@ -422,16 +384,14 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled" self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e: e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403)
self.assertEquals(e.exception.code, 403)
@defer.inlineCallbacks
def test_server_notices_mxid_special_cased(self): def test_server_notices_mxid_special_cased(self):
self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled = True
user = "@user:server" user = "@user:server"
self.auth_blocking._server_notices_mxid = user self.auth_blocking._server_notices_mxid = user
self.auth_blocking._hs_disabled_message = "Reason for being disabled" self.auth_blocking._hs_disabled_message = "Reason for being disabled"
yield defer.ensureDeferred(self.auth.check_auth_blocking(user)) self.get_success(self.auth.check_auth_blocking(user))

View File

@ -18,15 +18,12 @@
import jsonschema import jsonschema
from twisted.internet import defer
from synapse.api.constants import EventContentFields from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
from tests import unittest from tests import unittest
from tests.utils import setup_test_homeserver
user_localpart = "test_user" user_localpart = "test_user"
@ -39,9 +36,8 @@ def MockEvent(**kwargs):
return make_event_from_dict(kwargs) return make_event_from_dict(kwargs)
class FilteringTestCase(unittest.TestCase): class FilteringTestCase(unittest.HomeserverTestCase):
def setUp(self): def prepare(self, reactor, clock, hs):
hs = setup_test_homeserver(self.addCleanup)
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
@ -351,10 +347,9 @@ class FilteringTestCase(unittest.TestCase):
self.assertTrue(Filter(definition).check(event)) self.assertTrue(Filter(definition).check(event))
@defer.inlineCallbacks
def test_filter_presence_match(self): def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}} user_filter_json = {"presence": {"types": ["m.*"]}}
filter_id = yield defer.ensureDeferred( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json user_localpart=user_localpart, user_filter=user_filter_json
) )
@ -362,7 +357,7 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.profile") event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event] events = [event]
user_filter = yield defer.ensureDeferred( user_filter = self.get_success(
self.filtering.get_user_filter( self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id user_localpart=user_localpart, filter_id=filter_id
) )
@ -371,11 +366,10 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_presence(events=events) results = user_filter.filter_presence(events=events)
self.assertEquals(events, results) self.assertEquals(events, results)
@defer.inlineCallbacks
def test_filter_presence_no_match(self): def test_filter_presence_no_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}} user_filter_json = {"presence": {"types": ["m.*"]}}
filter_id = yield defer.ensureDeferred( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_filter=user_filter_json user_localpart=user_localpart + "2", user_filter=user_filter_json
) )
@ -387,7 +381,7 @@ class FilteringTestCase(unittest.TestCase):
) )
events = [event] events = [event]
user_filter = yield defer.ensureDeferred( user_filter = self.get_success(
self.filtering.get_user_filter( self.filtering.get_user_filter(
user_localpart=user_localpart + "2", filter_id=filter_id user_localpart=user_localpart + "2", filter_id=filter_id
) )
@ -396,10 +390,9 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_presence(events=events) results = user_filter.filter_presence(events=events)
self.assertEquals([], results) self.assertEquals([], results)
@defer.inlineCallbacks
def test_filter_room_state_match(self): def test_filter_room_state_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield defer.ensureDeferred( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json user_localpart=user_localpart, user_filter=user_filter_json
) )
@ -407,7 +400,7 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event] events = [event]
user_filter = yield defer.ensureDeferred( user_filter = self.get_success(
self.filtering.get_user_filter( self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id user_localpart=user_localpart, filter_id=filter_id
) )
@ -416,10 +409,9 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_room_state(events=events) results = user_filter.filter_room_state(events=events)
self.assertEquals(events, results) self.assertEquals(events, results)
@defer.inlineCallbacks
def test_filter_room_state_no_match(self): def test_filter_room_state_no_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield defer.ensureDeferred( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json user_localpart=user_localpart, user_filter=user_filter_json
) )
@ -429,7 +421,7 @@ class FilteringTestCase(unittest.TestCase):
) )
events = [event] events = [event]
user_filter = yield defer.ensureDeferred( user_filter = self.get_success(
self.filtering.get_user_filter( self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id user_localpart=user_localpart, filter_id=filter_id
) )
@ -454,11 +446,10 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(filtered_room_ids, ["!allowed:example.com"]) self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
@defer.inlineCallbacks
def test_add_filter(self): def test_add_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield defer.ensureDeferred( filter_id = self.get_success(
self.filtering.add_user_filter( self.filtering.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json user_localpart=user_localpart, user_filter=user_filter_json
) )
@ -468,7 +459,7 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
user_filter_json, user_filter_json,
( (
yield defer.ensureDeferred( self.get_success(
self.datastore.get_user_filter( self.datastore.get_user_filter(
user_localpart=user_localpart, filter_id=0 user_localpart=user_localpart, filter_id=0
) )
@ -476,17 +467,16 @@ class FilteringTestCase(unittest.TestCase):
), ),
) )
@defer.inlineCallbacks
def test_get_filter(self): def test_get_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield defer.ensureDeferred( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json user_localpart=user_localpart, user_filter=user_filter_json
) )
) )
filter = yield defer.ensureDeferred( filter = self.get_success(
self.filtering.get_user_filter( self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id user_localpart=user_localpart, filter_id=filter_id
) )

View File

@ -35,8 +35,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_scheduler = Mock() self.mock_scheduler = Mock()
hs = Mock() hs = Mock()
hs.get_datastore.return_value = self.mock_store hs.get_datastore.return_value = self.mock_store
self.mock_store.get_received_ts.return_value = defer.succeed(0) self.mock_store.get_received_ts.return_value = make_awaitable(0)
self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
@ -50,16 +50,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice(is_interested=False), self._mkservice(is_interested=False),
] ]
self.mock_as_api.query_user.return_value = defer.succeed(True) self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_app_services.return_value = services self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = defer.succeed([]) self.mock_store.get_user_by_id.return_value = make_awaitable([])
event = Mock( event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
) )
self.mock_store.get_new_events_for_appservice.side_effect = [ self.mock_store.get_new_events_for_appservice.side_effect = [
defer.succeed((0, [event])), make_awaitable((0, [event])),
defer.succeed((0, [])), make_awaitable((0, [])),
] ]
self.handler.notify_interested_services(RoomStreamToken(None, 0)) self.handler.notify_interested_services(RoomStreamToken(None, 0))
@ -72,13 +72,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested=True)] services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user.return_value = True services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = defer.succeed(None) self.mock_store.get_user_by_id.return_value = make_awaitable(None)
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = defer.succeed(True) self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_new_events_for_appservice.side_effect = [ self.mock_store.get_new_events_for_appservice.side_effect = [
defer.succeed((0, [event])), make_awaitable((0, [event])),
defer.succeed((0, [])), make_awaitable((0, [])),
] ]
self.handler.notify_interested_services(RoomStreamToken(None, 0)) self.handler.notify_interested_services(RoomStreamToken(None, 0))
@ -90,13 +90,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested=True)] services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user.return_value = True services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id}) self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = defer.succeed(True) self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_new_events_for_appservice.side_effect = [ self.mock_store.get_new_events_for_appservice.side_effect = [
defer.succeed((0, [event])), make_awaitable((0, [event])),
defer.succeed((0, [])), make_awaitable((0, [])),
] ]
self.handler.notify_interested_services(RoomStreamToken(None, 0)) self.handler.notify_interested_services(RoomStreamToken(None, 0))
@ -106,7 +106,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
"query_user called when it shouldn't have been.", "query_user called when it shouldn't have been.",
) )
@defer.inlineCallbacks
def test_query_room_alias_exists(self): def test_query_room_alias_exists(self):
room_alias_str = "#foo:bar" room_alias_str = "#foo:bar"
room_alias = Mock() room_alias = Mock()
@ -127,8 +126,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
Mock(room_id=room_id, servers=servers) Mock(room_id=room_id, servers=servers)
) )
result = yield defer.ensureDeferred( result = self.successResultOf(
self.handler.query_room_alias_exists(room_alias) defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
) )
self.mock_as_api.query_alias.assert_called_once_with( self.mock_as_api.query_alias.assert_called_once_with(

View File

@ -14,163 +14,11 @@
# limitations under the License. # limitations under the License.
"""Tests REST events for /profile paths.""" """Tests REST events for /profile paths."""
import json
from mock import Mock
from twisted.internet import defer
import synapse.types
from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client.v1 import login, profile, room from synapse.rest.client.v1 import login, profile, room
from tests import unittest from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test"
PATH_PREFIX = "/_matrix/client/r0"
class MockHandlerProfileTestCase(unittest.TestCase):
""" Tests rest layer of profile management.
Todo: move these into ProfileTestCase
"""
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.mock_handler = Mock(
spec=[
"get_displayname",
"set_displayname",
"get_avatar_url",
"set_avatar_url",
"check_profile_query_allowed",
]
)
self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
Mock()
)
hs = yield setup_test_homeserver(
self.addCleanup,
"test",
federation_http_client=None,
resource_for_client=self.mock_resource,
federation=Mock(),
federation_client=Mock(),
profile_handler=self.mock_handler,
)
async def _get_user_by_req(request=None, allow_guest=False):
return synapse.types.create_requester(myid)
hs.get_auth().get_user_by_req = _get_user_by_req
profile.register_servlets(hs, self.mock_resource)
@defer.inlineCallbacks
def test_get_my_name(self):
mocked_get = self.mock_handler.get_displayname
mocked_get.return_value = defer.succeed("Frank")
(code, response) = yield self.mock_resource.trigger(
"GET", "/profile/%s/displayname" % (myid), None
)
self.assertEquals(200, code)
self.assertEquals({"displayname": "Frank"}, response)
self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
@defer.inlineCallbacks
def test_set_my_name(self):
mocked_set = self.mock_handler.set_displayname
mocked_set.return_value = defer.succeed(())
(code, response) = yield self.mock_resource.trigger(
"PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}'
)
self.assertEquals(200, code)
self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.")
@defer.inlineCallbacks
def test_set_my_name_noauth(self):
mocked_set = self.mock_handler.set_displayname
mocked_set.side_effect = AuthError(400, "message")
(code, response) = yield self.mock_resource.trigger(
"PUT",
"/profile/%s/displayname" % ("@4567:test"),
b'{"displayname": "Frank Jr."}',
)
self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code))
@defer.inlineCallbacks
def test_get_other_name(self):
mocked_get = self.mock_handler.get_displayname
mocked_get.return_value = defer.succeed("Bob")
(code, response) = yield self.mock_resource.trigger(
"GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None
)
self.assertEquals(200, code)
self.assertEquals({"displayname": "Bob"}, response)
@defer.inlineCallbacks
def test_set_other_name(self):
mocked_set = self.mock_handler.set_displayname
mocked_set.side_effect = SynapseError(400, "message")
(code, response) = yield self.mock_resource.trigger(
"PUT",
"/profile/%s/displayname" % ("@opaque:elsewhere"),
b'{"displayname":"bob"}',
)
self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code))
@defer.inlineCallbacks
def test_get_my_avatar(self):
mocked_get = self.mock_handler.get_avatar_url
mocked_get.return_value = defer.succeed("http://my.server/me.png")
(code, response) = yield self.mock_resource.trigger(
"GET", "/profile/%s/avatar_url" % (myid), None
)
self.assertEquals(200, code)
self.assertEquals({"avatar_url": "http://my.server/me.png"}, response)
self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
@defer.inlineCallbacks
def test_set_my_avatar(self):
mocked_set = self.mock_handler.set_avatar_url
mocked_set.return_value = defer.succeed(())
(code, response) = yield self.mock_resource.trigger(
"PUT",
"/profile/%s/avatar_url" % (myid),
b'{"avatar_url": "http://my.server/pic.gif"}',
)
self.assertEquals(200, code)
self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
class ProfileTestCase(unittest.HomeserverTestCase): class ProfileTestCase(unittest.HomeserverTestCase):
@ -187,37 +35,122 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.owner = self.register_user("owner", "pass") self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass") self.owner_tok = self.login("owner", "pass")
self.other = self.register_user("other", "pass", displayname="Bob")
def test_get_displayname(self):
res = self._get_displayname()
self.assertEqual(res, "owner")
def test_set_displayname(self): def test_set_displayname(self):
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
"/profile/%s/displayname" % (self.owner,), "/profile/%s/displayname" % (self.owner,),
content=json.dumps({"displayname": "test"}), content={"displayname": "test"},
access_token=self.owner_tok, access_token=self.owner_tok,
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
res = self.get_displayname() res = self._get_displayname()
self.assertEqual(res, "test") self.assertEqual(res, "test")
def test_set_displayname_noauth(self):
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content={"displayname": "test"},
)
self.assertEqual(channel.code, 401, channel.result)
def test_set_displayname_too_long(self): def test_set_displayname_too_long(self):
"""Attempts to set a stupid displayname should get a 400""" """Attempts to set a stupid displayname should get a 400"""
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
"/profile/%s/displayname" % (self.owner,), "/profile/%s/displayname" % (self.owner,),
content=json.dumps({"displayname": "test" * 100}), content={"displayname": "test" * 100},
access_token=self.owner_tok, access_token=self.owner_tok,
) )
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
res = self.get_displayname() res = self._get_displayname()
self.assertEqual(res, "owner") self.assertEqual(res, "owner")
def get_displayname(self): def test_get_displayname_other(self):
channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,)) res = self._get_displayname(self.other)
self.assertEquals(res, "Bob")
def test_set_displayname_other(self):
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.other,),
content={"displayname": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
def test_get_avatar_url(self):
res = self._get_avatar_url()
self.assertIsNone(res)
def test_set_avatar_url(self):
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.owner,),
content={"avatar_url": "http://my.server/pic.gif"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
res = self._get_avatar_url()
self.assertEqual(res, "http://my.server/pic.gif")
def test_set_avatar_url_noauth(self):
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.owner,),
content={"avatar_url": "http://my.server/pic.gif"},
)
self.assertEqual(channel.code, 401, channel.result)
def test_set_avatar_url_too_long(self):
"""Attempts to set a stupid avatar_url should get a 400"""
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.owner,),
content={"avatar_url": "http://my.server/pic.gif" * 100},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
res = self._get_avatar_url()
self.assertIsNone(res)
def test_get_avatar_url_other(self):
res = self._get_avatar_url(self.other)
self.assertIsNone(res)
def test_set_avatar_url_other(self):
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.other,),
content={"avatar_url": "http://my.server/pic.gif"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
def _get_displayname(self, name=None):
channel = self.make_request(
"GET", "/profile/%s/displayname" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["displayname"] return channel.json_body["displayname"]
def _get_avatar_url(self, name=None):
channel = self.make_request(
"GET", "/profile/%s/avatar_url" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body.get("avatar_url")
class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):