mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-02 22:14:55 -04:00
Convert auth handler to async/await (#7261)
This commit is contained in:
parent
17a2433b0d
commit
eed7c5b89e
10 changed files with 224 additions and 170 deletions
|
@ -68,7 +68,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
|
||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||
|
||||
def test_get_user_by_req_user_bad_token(self):
|
||||
|
@ -105,7 +105,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
request.getClientIP.return_value = "127.0.0.1"
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
|
||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -125,7 +125,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
request.getClientIP.return_value = "192.168.10.10"
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
|
||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||
|
||||
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
|
||||
|
@ -188,7 +188,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
request.args[b"access_token"] = [self.test_token]
|
||||
request.args[b"user_id"] = [masquerading_user_id]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
|
||||
self.assertEquals(
|
||||
requester.user.to_string(), masquerading_user_id.decode("utf8")
|
||||
)
|
||||
|
@ -225,7 +225,9 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
user_info = yield defer.ensureDeferred(
|
||||
self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
)
|
||||
user = user_info["user"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
|
||||
|
@ -250,7 +252,9 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("guest = true")
|
||||
serialized = macaroon.serialize()
|
||||
|
||||
user_info = yield self.auth.get_user_by_access_token(serialized)
|
||||
user_info = yield defer.ensureDeferred(
|
||||
self.auth.get_user_by_access_token(serialized)
|
||||
)
|
||||
user = user_info["user"]
|
||||
is_guest = user_info["is_guest"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
|
@ -260,10 +264,13 @@ class AuthTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_cannot_use_regular_token_as_guest(self):
|
||||
USER_ID = "@percy:matrix.org"
|
||||
self.store.add_access_token_to_user = Mock()
|
||||
self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
|
||||
self.store.get_device = Mock(return_value=defer.succeed(None))
|
||||
|
||||
token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id(
|
||||
USER_ID, "DEVICE", valid_until_ms=None
|
||||
token = yield defer.ensureDeferred(
|
||||
self.hs.handlers.auth_handler.get_access_token_for_user_id(
|
||||
USER_ID, "DEVICE", valid_until_ms=None
|
||||
)
|
||||
)
|
||||
self.store.add_access_token_to_user.assert_called_with(
|
||||
USER_ID, token, "DEVICE", None
|
||||
|
@ -286,7 +293,9 @@ class AuthTestCase(unittest.TestCase):
|
|||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [token.encode("ascii")]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester = yield defer.ensureDeferred(
|
||||
self.auth.get_user_by_req(request, allow_guest=True)
|
||||
)
|
||||
self.assertEqual(UserID.from_string(USER_ID), requester.user)
|
||||
self.assertFalse(requester.is_guest)
|
||||
|
||||
|
@ -301,7 +310,9 @@ class AuthTestCase(unittest.TestCase):
|
|||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
|
||||
with self.assertRaises(InvalidClientCredentialsError) as cm:
|
||||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
yield defer.ensureDeferred(
|
||||
self.auth.get_user_by_req(request, allow_guest=True)
|
||||
)
|
||||
|
||||
self.assertEqual(401, cm.exception.code)
|
||||
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
|
||||
|
@ -316,7 +327,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
small_number_of_users = 1
|
||||
|
||||
# Ensure no error thrown
|
||||
yield self.auth.check_auth_blocking()
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking())
|
||||
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
|
||||
|
@ -325,7 +336,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
with self.assertRaises(ResourceLimitError) as e:
|
||||
yield self.auth.check_auth_blocking()
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking())
|
||||
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
self.assertEquals(e.exception.code, 403)
|
||||
|
@ -334,7 +345,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(small_number_of_users)
|
||||
)
|
||||
yield self.auth.check_auth_blocking()
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_blocking_mau__depending_on_user_type(self):
|
||||
|
@ -343,15 +354,19 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
|
||||
# Support users allowed
|
||||
yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
|
||||
yield defer.ensureDeferred(
|
||||
self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
|
||||
)
|
||||
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
|
||||
# Bots not allowed
|
||||
with self.assertRaises(ResourceLimitError):
|
||||
yield self.auth.check_auth_blocking(user_type=UserTypes.BOT)
|
||||
yield defer.ensureDeferred(
|
||||
self.auth.check_auth_blocking(user_type=UserTypes.BOT)
|
||||
)
|
||||
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
|
||||
# Real users not allowed
|
||||
with self.assertRaises(ResourceLimitError):
|
||||
yield self.auth.check_auth_blocking()
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_reserved_threepid(self):
|
||||
|
@ -362,21 +377,22 @@ class AuthTestCase(unittest.TestCase):
|
|||
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
|
||||
self.hs.config.mau_limits_reserved_threepids = [threepid]
|
||||
|
||||
yield self.store.register_user(user_id="user1", password_hash=None)
|
||||
with self.assertRaises(ResourceLimitError):
|
||||
yield self.auth.check_auth_blocking()
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking())
|
||||
|
||||
with self.assertRaises(ResourceLimitError):
|
||||
yield self.auth.check_auth_blocking(threepid=unknown_threepid)
|
||||
yield defer.ensureDeferred(
|
||||
self.auth.check_auth_blocking(threepid=unknown_threepid)
|
||||
)
|
||||
|
||||
yield self.auth.check_auth_blocking(threepid=threepid)
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_hs_disabled(self):
|
||||
self.hs.config.hs_disabled = True
|
||||
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
||||
with self.assertRaises(ResourceLimitError) as e:
|
||||
yield self.auth.check_auth_blocking()
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking())
|
||||
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
self.assertEquals(e.exception.code, 403)
|
||||
|
@ -393,7 +409,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.hs.config.hs_disabled = True
|
||||
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
||||
with self.assertRaises(ResourceLimitError) as e:
|
||||
yield self.auth.check_auth_blocking()
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking())
|
||||
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
self.assertEquals(e.exception.code, 403)
|
||||
|
@ -404,4 +420,4 @@ class AuthTestCase(unittest.TestCase):
|
|||
user = "@user:server"
|
||||
self.hs.config.server_notices_mxid = user
|
||||
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
||||
yield self.auth.check_auth_blocking(user)
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue