mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-06 12:04:17 -04:00
Allow additional SSO properties to be passed to the client (#8413)
This commit is contained in:
parent
ceafb5a1c6
commit
8b40843392
9 changed files with 278 additions and 67 deletions
|
@ -21,7 +21,6 @@ from mock import Mock, patch
|
|||
import attr
|
||||
import pymacaroons
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web._newclient import ResponseDone
|
||||
|
||||
|
@ -87,6 +86,13 @@ class TestMappingProvider(OidcMappingProvider):
|
|||
async def map_user_attributes(self, userinfo, token):
|
||||
return {"localpart": userinfo["username"], "display_name": None}
|
||||
|
||||
# Do not include get_extra_attributes to test backwards compatibility paths.
|
||||
|
||||
|
||||
class TestMappingProviderExtra(TestMappingProvider):
|
||||
async def get_extra_attributes(self, userinfo, token):
|
||||
return {"phone": userinfo["phone"]}
|
||||
|
||||
|
||||
def simple_async_mock(return_value=None, raises=None):
|
||||
# AsyncMock is not available in python3.5, this mimics part of its behaviour
|
||||
|
@ -126,7 +132,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
|
||||
config = self.default_config()
|
||||
config["public_baseurl"] = BASE_URL
|
||||
oidc_config = config.get("oidc_config", {})
|
||||
oidc_config = {}
|
||||
oidc_config["enabled"] = True
|
||||
oidc_config["client_id"] = CLIENT_ID
|
||||
oidc_config["client_secret"] = CLIENT_SECRET
|
||||
|
@ -135,6 +141,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
oidc_config["user_mapping_provider"] = {
|
||||
"module": __name__ + ".TestMappingProvider",
|
||||
}
|
||||
|
||||
# Update this config with what's in the default config so that
|
||||
# override_config works as expected.
|
||||
oidc_config.update(config.get("oidc_config", {}))
|
||||
config["oidc_config"] = oidc_config
|
||||
|
||||
hs = self.setup_test_homeserver(
|
||||
|
@ -165,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
|
||||
|
||||
@override_config({"oidc_config": {"discover": True}})
|
||||
@defer.inlineCallbacks
|
||||
def test_discovery(self):
|
||||
"""The handler should discover the endpoints from OIDC discovery document."""
|
||||
# This would throw if some metadata were invalid
|
||||
metadata = yield defer.ensureDeferred(self.handler.load_metadata())
|
||||
metadata = self.get_success(self.handler.load_metadata())
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
|
||||
self.assertEqual(metadata.issuer, ISSUER)
|
||||
|
@ -181,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
|
||||
# subsequent calls should be cached
|
||||
self.http_client.reset_mock()
|
||||
yield defer.ensureDeferred(self.handler.load_metadata())
|
||||
self.get_success(self.handler.load_metadata())
|
||||
self.http_client.get_json.assert_not_called()
|
||||
|
||||
@override_config({"oidc_config": COMMON_CONFIG})
|
||||
@defer.inlineCallbacks
|
||||
def test_no_discovery(self):
|
||||
"""When discovery is disabled, it should not try to load from discovery document."""
|
||||
yield defer.ensureDeferred(self.handler.load_metadata())
|
||||
self.get_success(self.handler.load_metadata())
|
||||
self.http_client.get_json.assert_not_called()
|
||||
|
||||
@override_config({"oidc_config": COMMON_CONFIG})
|
||||
@defer.inlineCallbacks
|
||||
def test_load_jwks(self):
|
||||
"""JWKS loading is done once (then cached) if used."""
|
||||
jwks = yield defer.ensureDeferred(self.handler.load_jwks())
|
||||
jwks = self.get_success(self.handler.load_jwks())
|
||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||
self.assertEqual(jwks, {"keys": []})
|
||||
|
||||
# subsequent calls should be cached…
|
||||
self.http_client.reset_mock()
|
||||
yield defer.ensureDeferred(self.handler.load_jwks())
|
||||
self.get_success(self.handler.load_jwks())
|
||||
self.http_client.get_json.assert_not_called()
|
||||
|
||||
# …unless forced
|
||||
self.http_client.reset_mock()
|
||||
yield defer.ensureDeferred(self.handler.load_jwks(force=True))
|
||||
self.get_success(self.handler.load_jwks(force=True))
|
||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||
|
||||
# Throw if the JWKS uri is missing
|
||||
with self.metadata_edit({"jwks_uri": None}):
|
||||
with self.assertRaises(RuntimeError):
|
||||
yield defer.ensureDeferred(self.handler.load_jwks(force=True))
|
||||
self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
|
||||
|
||||
# Return empty key set if JWKS are not used
|
||||
self.handler._scopes = [] # not asking the openid scope
|
||||
self.http_client.get_json.reset_mock()
|
||||
jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
|
||||
jwks = self.get_success(self.handler.load_jwks(force=True))
|
||||
self.http_client.get_json.assert_not_called()
|
||||
self.assertEqual(jwks, {"keys": []})
|
||||
|
||||
|
@ -299,11 +305,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
# This should not throw
|
||||
self.handler._validate_metadata()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_redirect_request(self):
|
||||
"""The redirect request has the right arguments & generates a valid session cookie."""
|
||||
req = Mock(spec=["addCookie"])
|
||||
url = yield defer.ensureDeferred(
|
||||
url = self.get_success(
|
||||
self.handler.handle_redirect_request(req, b"http://client/redirect")
|
||||
)
|
||||
url = urlparse(url)
|
||||
|
@ -343,20 +348,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
self.assertEqual(params["nonce"], [nonce])
|
||||
self.assertEqual(redirect, "http://client/redirect")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_callback_error(self):
|
||||
"""Errors from the provider returned in the callback are displayed."""
|
||||
self.handler._render_error = Mock()
|
||||
request = Mock(args={})
|
||||
request.args[b"error"] = [b"invalid_client"]
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_client", "")
|
||||
|
||||
request.args[b"error_description"] = [b"some description"]
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_client", "some description")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_callback(self):
|
||||
"""Code callback works and display errors if something went wrong.
|
||||
|
||||
|
@ -377,7 +380,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
"sub": "foo",
|
||||
"preferred_username": "bar",
|
||||
}
|
||||
user_id = UserID("foo", "domain.org")
|
||||
user_id = "@foo:domain.org"
|
||||
self.handler._render_error = Mock(return_value=None)
|
||||
self.handler._exchange_code = simple_async_mock(return_value=token)
|
||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
|
@ -394,13 +397,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
client_redirect_url = "http://client/redirect"
|
||||
user_agent = "Browser"
|
||||
ip_address = "10.0.0.1"
|
||||
session = self.handler._generate_oidc_session_token(
|
||||
request.getCookie.return_value = self.handler._generate_oidc_session_token(
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id=None,
|
||||
)
|
||||
request.getCookie.return_value = session
|
||||
|
||||
request.args = {}
|
||||
request.args[b"code"] = [code.encode("utf-8")]
|
||||
|
@ -410,10 +412,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
|
||||
request.getClientIP.return_value = ip_address
|
||||
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user_id, request, client_redirect_url,
|
||||
user_id, request, client_redirect_url, {},
|
||||
)
|
||||
self.handler._exchange_code.assert_called_once_with(code)
|
||||
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||
|
@ -427,13 +429,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
self.handler._map_userinfo_to_user = simple_async_mock(
|
||||
raises=MappingException()
|
||||
)
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("mapping_error")
|
||||
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
|
||||
|
||||
# Handle ID token errors
|
||||
self.handler._parse_id_token = simple_async_mock(raises=Exception())
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_token")
|
||||
|
||||
self.handler._auth_handler.complete_sso_login.reset_mock()
|
||||
|
@ -444,10 +446,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
|
||||
# With userinfo fetching
|
||||
self.handler._scopes = [] # do not ask the "openid" scope
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user_id, request, client_redirect_url,
|
||||
user_id, request, client_redirect_url, {},
|
||||
)
|
||||
self.handler._exchange_code.assert_called_once_with(code)
|
||||
self.handler._parse_id_token.assert_not_called()
|
||||
|
@ -459,17 +461,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
|
||||
# Handle userinfo fetching error
|
||||
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("fetch_error")
|
||||
|
||||
# Handle code exchange failure
|
||||
self.handler._exchange_code = simple_async_mock(
|
||||
raises=OidcError("invalid_request")
|
||||
)
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_request")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_callback_session(self):
|
||||
"""The callback verifies the session presence and validity"""
|
||||
self.handler._render_error = Mock(return_value=None)
|
||||
|
@ -478,20 +479,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
# Missing cookie
|
||||
request.args = {}
|
||||
request.getCookie.return_value = None
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("missing_session", "No session cookie found")
|
||||
|
||||
# Missing session parameter
|
||||
request.args = {}
|
||||
request.getCookie.return_value = "session"
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_request", "State parameter is missing")
|
||||
|
||||
# Invalid cookie
|
||||
request.args = {}
|
||||
request.args[b"state"] = [b"state"]
|
||||
request.getCookie.return_value = "session"
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_session")
|
||||
|
||||
# Mismatching session
|
||||
|
@ -504,18 +505,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
request.args = {}
|
||||
request.args[b"state"] = [b"mismatching state"]
|
||||
request.getCookie.return_value = session
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("mismatching_session")
|
||||
|
||||
# Valid session
|
||||
request.args = {}
|
||||
request.args[b"state"] = [b"state"]
|
||||
request.getCookie.return_value = session
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_request")
|
||||
|
||||
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
|
||||
@defer.inlineCallbacks
|
||||
def test_exchange_code(self):
|
||||
"""Code exchange behaves correctly and handles various error scenarios."""
|
||||
token = {"type": "bearer"}
|
||||
|
@ -524,7 +524,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
|
||||
)
|
||||
code = "code"
|
||||
ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
|
||||
ret = self.get_success(self.handler._exchange_code(code))
|
||||
kwargs = self.http_client.request.call_args[1]
|
||||
|
||||
self.assertEqual(ret, token)
|
||||
|
@ -546,10 +546,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
body=b'{"error": "foo", "error_description": "bar"}',
|
||||
)
|
||||
)
|
||||
with self.assertRaises(OidcError) as exc:
|
||||
yield defer.ensureDeferred(self.handler._exchange_code(code))
|
||||
self.assertEqual(exc.exception.error, "foo")
|
||||
self.assertEqual(exc.exception.error_description, "bar")
|
||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||
self.assertEqual(exc.value.error, "foo")
|
||||
self.assertEqual(exc.value.error_description, "bar")
|
||||
|
||||
# Internal server error with no JSON body
|
||||
self.http_client.request = simple_async_mock(
|
||||
|
@ -557,9 +556,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
|
||||
)
|
||||
)
|
||||
with self.assertRaises(OidcError) as exc:
|
||||
yield defer.ensureDeferred(self.handler._exchange_code(code))
|
||||
self.assertEqual(exc.exception.error, "server_error")
|
||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||
self.assertEqual(exc.value.error, "server_error")
|
||||
|
||||
# Internal server error with JSON body
|
||||
self.http_client.request = simple_async_mock(
|
||||
|
@ -569,17 +567,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
body=b'{"error": "internal_server_error"}',
|
||||
)
|
||||
)
|
||||
with self.assertRaises(OidcError) as exc:
|
||||
yield defer.ensureDeferred(self.handler._exchange_code(code))
|
||||
self.assertEqual(exc.exception.error, "internal_server_error")
|
||||
|
||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||
self.assertEqual(exc.value.error, "internal_server_error")
|
||||
|
||||
# 4xx error without "error" field
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
|
||||
)
|
||||
with self.assertRaises(OidcError) as exc:
|
||||
yield defer.ensureDeferred(self.handler._exchange_code(code))
|
||||
self.assertEqual(exc.exception.error, "server_error")
|
||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||
self.assertEqual(exc.value.error, "server_error")
|
||||
|
||||
# 2xx error with "error" field
|
||||
self.http_client.request = simple_async_mock(
|
||||
|
@ -587,9 +584,62 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
|
||||
)
|
||||
)
|
||||
with self.assertRaises(OidcError) as exc:
|
||||
yield defer.ensureDeferred(self.handler._exchange_code(code))
|
||||
self.assertEqual(exc.exception.error, "some_error")
|
||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||
self.assertEqual(exc.value.error, "some_error")
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"oidc_config": {
|
||||
"user_mapping_provider": {
|
||||
"module": __name__ + ".TestMappingProviderExtra"
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_extra_attributes(self):
|
||||
"""
|
||||
Login while using a mapping provider that implements get_extra_attributes.
|
||||
"""
|
||||
token = {
|
||||
"type": "bearer",
|
||||
"id_token": "id_token",
|
||||
"access_token": "access_token",
|
||||
}
|
||||
userinfo = {
|
||||
"sub": "foo",
|
||||
"phone": "1234567",
|
||||
}
|
||||
user_id = "@foo:domain.org"
|
||||
self.handler._exchange_code = simple_async_mock(return_value=token)
|
||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
|
||||
self.handler._auth_handler.complete_sso_login = simple_async_mock()
|
||||
request = Mock(
|
||||
spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
|
||||
)
|
||||
|
||||
state = "state"
|
||||
client_redirect_url = "http://client/redirect"
|
||||
request.getCookie.return_value = self.handler._generate_oidc_session_token(
|
||||
state=state,
|
||||
nonce="nonce",
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id=None,
|
||||
)
|
||||
|
||||
request.args = {}
|
||||
request.args[b"code"] = [b"code"]
|
||||
request.args[b"state"] = [state.encode("utf-8")]
|
||||
|
||||
request.requestHeaders = Mock(spec=["getRawHeaders"])
|
||||
request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
|
||||
request.getClientIP.return_value = "10.0.0.1"
|
||||
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user_id, request, client_redirect_url, {"phone": "1234567"},
|
||||
)
|
||||
|
||||
def test_map_userinfo_to_user(self):
|
||||
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue