Split OidcProvider out of OidcHandler (#9107)

The idea here is that we will have an instance of OidcProvider for each
configured IdP, with OidcHandler just doing the marshalling of them.

For now it's still hardcoded with a single provider.
This commit is contained in:
Richard van der Hoff 2021-01-14 13:29:17 +00:00 committed by GitHub
parent 12702be951
commit 21a296cd5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 197 additions and 144 deletions

1
changelog.d/9107.feature Normal file
View File

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

View File

@ -429,7 +429,6 @@ def setup(config_options):
oidc = hs.get_oidc_handler() oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid. # Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata() await oidc.load_metadata()
await oidc.load_jwks()
await _base.start(hs, config.listeners) await _base.start(hs, config.listeners)

View File

@ -35,6 +35,7 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody from twisted.web.client import readBody
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.oidc_config import OidcProviderConfig
from synapse.handlers.sso import MappingException, UserAttributes from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@ -70,6 +71,131 @@ JWK = Dict[str, str]
JWKS = TypedDict("JWKS", {"keys": List[JWK]}) JWKS = TypedDict("JWKS", {"keys": List[JWK]})
class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
self._sso_handler = hs.get_sso_handler()
provider_conf = hs.config.oidc.oidc_provider
# we should not have been instantiated if there is no configured provider.
assert provider_conf is not None
self._token_generator = OidcSessionTokenGenerator(hs)
self._provider = OidcProvider(hs, self._token_generator, provider_conf)
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
Called at startup to ensure we have everything we need.
"""
await self._provider.load_metadata()
await self._provider.load_jwks()
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
- first, we check if there was any error returned by the provider and
display it
- then we fetch the session cookie, decode and verify it
- the ``state`` query parameter should match with the one stored in the
session cookie
Once we know the session is legit, we then delegate to the OIDC Provider
implementation, which will exchange the code with the provider and complete the
login/authentication.
Args:
request: the incoming request from the browser.
"""
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
# error response from the auth server. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
error = request.args[b"error"][0].decode()
description = request.args.get(b"error_description", [b""])[0].decode()
# Most of the errors returned by the provider could be due by
# either the provider misbehaving or Synapse being misconfigured.
# The only exception of that is "access_denied", where the user
# probably cancelled the login flow. In other cases, log those errors.
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2
# Fetch the session cookie
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return
# Remove the cookie. There is a good chance that if the callback failed
# once, it will fail next time and the code will already be exchanged.
# Removing it early avoids spamming the provider with token requests.
request.addCookie(
SESSION_COOKIE_NAME,
b"",
path="/_synapse/oidc",
expires="Thu, Jan 01 1970 00:00:00 UTC",
httpOnly=True,
sameSite="lax",
)
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return
state = request.args[b"state"][0].decode()
# Deserialize the session token and verify it.
try:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return
code = request.args[b"code"][0].decode()
await self._provider.handle_oidc_callback(request, session_data, code)
class OidcError(Exception): class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint """Used to catch errors when calling the token_endpoint
""" """
@ -84,21 +210,25 @@ class OidcError(Exception):
return self.error return self.error
class OidcHandler: class OidcProvider:
"""Handles requests related to the OpenID Connect login flow. """Wraps the config for a single OIDC IdentityProvider
Provides methods for handling redirect requests and callbacks via that particular
IdP.
""" """
def __init__(self, hs: "HomeServer"): def __init__(
self,
hs: "HomeServer",
token_generator: "OidcSessionTokenGenerator",
provider: OidcProviderConfig,
):
self._store = hs.get_datastore() self._store = hs.get_datastore()
self._token_generator = OidcSessionTokenGenerator(hs) self._token_generator = token_generator
self._callback_url = hs.config.oidc_callback_url # type: str self._callback_url = hs.config.oidc_callback_url # type: str
provider = hs.config.oidc.oidc_provider
# we should not have been instantiated if there is no configured provider.
assert provider is not None
self._scopes = provider.scopes self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method self._user_profile_method = provider.user_profile_method
self._client_auth = ClientAuth( self._client_auth = ClientAuth(
@ -552,22 +682,16 @@ class OidcHandler:
nonce=nonce, nonce=nonce,
) )
async def handle_oidc_callback(self, request: SynapseRequest) -> None: async def handle_oidc_callback(
self, request: SynapseRequest, session_data: "OidcSessionData", code: str
) -> None:
"""Handle an incoming request to /_synapse/oidc/callback """Handle an incoming request to /_synapse/oidc/callback
Since we might want to display OIDC-related errors in a user-friendly By this time we have already validated the session on the synapse side, and
way, we don't raise SynapseError from here. Instead, we call now need to do the provider-specific operations. This includes:
``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here: - exchange the code with the provider using the ``token_endpoint`` (see
``_exchange_code``)
- first, we check if there was any error returned by the provider and
display it
- then we fetch the session cookie, decode and verify it
- the ``state`` query parameter should match with the one stored in the
session cookie
- once we known this session is legit, exchange the code with the
provider using the ``token_endpoint`` (see ``_exchange_code``)
- once we have the token, use it to either extract the UserInfo from - once we have the token, use it to either extract the UserInfo from
the ``id_token`` (``_parse_id_token``), or use the ``access_token`` the ``id_token`` (``_parse_id_token``), or use the ``access_token``
to fetch UserInfo from the ``userinfo_endpoint`` to fetch UserInfo from the ``userinfo_endpoint``
@ -577,86 +701,12 @@ class OidcHandler:
Args: Args:
request: the incoming request from the browser. request: the incoming request from the browser.
session_data: the session data, extracted from our cookie
code: The authorization code we got from the callback.
""" """
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
# error response from the auth server. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
error = request.args[b"error"][0].decode()
description = request.args.get(b"error_description", [b""])[0].decode()
# Most of the errors returned by the provider could be due by
# either the provider misbehaving or Synapse being misconfigured.
# The only exception of that is "access_denied", where the user
# probably cancelled the login flow. In other cases, log those errors.
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2
# Fetch the session cookie
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return
# Remove the cookie. There is a good chance that if the callback failed
# once, it will fail next time and the code will already be exchanged.
# Removing it early avoids spamming the provider with token requests.
request.addCookie(
SESSION_COOKIE_NAME,
b"",
path="/_synapse/oidc",
expires="Thu, Jan 01 1970 00:00:00 UTC",
httpOnly=True,
sameSite="lax",
)
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return
state = request.args[b"state"][0].decode()
# Deserialize the session token and verify it.
try:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
# Exchange the code with the provider # Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return
logger.debug("Exchanging code")
code = request.args[b"code"][0].decode()
try: try:
logger.debug("Exchanging code")
token = await self._exchange_code(code) token = await self._exchange_code(code)
except OidcError as e: except OidcError as e:
logger.exception("Could not exchange code") logger.exception("Could not exchange code")

View File

@ -151,6 +151,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver(proxied_http_client=self.http_client) hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
self.handler = hs.get_oidc_handler() self.handler = hs.get_oidc_handler()
self.provider = self.handler._provider
sso_handler = hs.get_sso_handler() sso_handler = hs.get_sso_handler()
# Mock the render error method. # Mock the render error method.
self.render_error = Mock(return_value=None) self.render_error = Mock(return_value=None)
@ -162,9 +163,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs return hs
def metadata_edit(self, values): def metadata_edit(self, values):
return patch.dict(self.handler._provider_metadata, values) return patch.dict(self.provider._provider_metadata, values)
def assertRenderedError(self, error, error_description=None): def assertRenderedError(self, error, error_description=None):
self.render_error.assert_called_once()
args = self.render_error.call_args[0] args = self.render_error.call_args[0]
self.assertEqual(args[1], error) self.assertEqual(args[1], error)
if error_description is not None: if error_description is not None:
@ -175,15 +177,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_config(self): def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly.""" """Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.handler._callback_url, CALLBACK_URL) self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID) self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}}) @override_config({"oidc_config": {"discover": True}})
def test_discovery(self): def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document.""" """The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid # This would throw if some metadata were invalid
metadata = self.get_success(self.handler.load_metadata()) metadata = self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_called_once_with(WELL_KNOWN) self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.assertEqual(metadata.issuer, ISSUER) self.assertEqual(metadata.issuer, ISSUER)
@ -195,47 +197,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached # subsequent calls should be cached
self.http_client.reset_mock() self.http_client.reset_mock()
self.get_success(self.handler.load_metadata()) self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called() self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG}) @override_config({"oidc_config": COMMON_CONFIG})
def test_no_discovery(self): def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document.""" """When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.handler.load_metadata()) self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called() self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG}) @override_config({"oidc_config": COMMON_CONFIG})
def test_load_jwks(self): def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used.""" """JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.handler.load_jwks()) jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI) self.http_client.get_json.assert_called_once_with(JWKS_URI)
self.assertEqual(jwks, {"keys": []}) self.assertEqual(jwks, {"keys": []})
# subsequent calls should be cached… # subsequent calls should be cached…
self.http_client.reset_mock() self.http_client.reset_mock()
self.get_success(self.handler.load_jwks()) self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_not_called() self.http_client.get_json.assert_not_called()
# …unless forced # …unless forced
self.http_client.reset_mock() self.http_client.reset_mock()
self.get_success(self.handler.load_jwks(force=True)) self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_called_once_with(JWKS_URI) self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing # Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}): with self.metadata_edit({"jwks_uri": None}):
self.get_failure(self.handler.load_jwks(force=True), RuntimeError) self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used # Return empty key set if JWKS are not used
self.handler._scopes = [] # not asking the openid scope self.provider._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock() self.http_client.get_json.reset_mock()
jwks = self.get_success(self.handler.load_jwks(force=True)) jwks = self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_not_called() self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []}) self.assertEqual(jwks, {"keys": []})
@override_config({"oidc_config": COMMON_CONFIG}) @override_config({"oidc_config": COMMON_CONFIG})
def test_validate_config(self): def test_validate_config(self):
"""Provider metadatas are extensively validated.""" """Provider metadatas are extensively validated."""
h = self.handler h = self.provider
# Default test config does not throw # Default test config does not throw
h._validate_metadata() h._validate_metadata()
@ -314,13 +316,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Provider metadata validation can be disabled by config.""" """Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}): with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw # This should not throw
self.handler._validate_metadata() self.provider._validate_metadata()
def test_redirect_request(self): def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie.""" """The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["addCookie"]) req = Mock(spec=["addCookie"])
url = self.get_success( url = self.get_success(
self.handler.handle_redirect_request(req, b"http://client/redirect") self.provider.handle_redirect_request(req, b"http://client/redirect")
) )
url = urlparse(url) url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@ -388,7 +390,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# ensure that we are correctly testing the fallback when "get_extra_attributes" # ensure that we are correctly testing the fallback when "get_extra_attributes"
# is not implemented. # is not implemented.
mapping_provider = self.handler._user_mapping_provider mapping_provider = self.provider._user_mapping_provider
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
_ = mapping_provider.get_extra_attributes _ = mapping_provider.get_extra_attributes
@ -403,9 +405,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": username, "username": username,
} }
expected_user_id = "@%s:%s" % (username, self.hs.hostname) expected_user_id = "@%s:%s" % (username, self.hs.hostname)
self.handler._exchange_code = simple_async_mock(return_value=token) self.provider._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo) self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -425,14 +427,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, None, expected_user_id, request, client_redirect_url, None,
) )
self.handler._exchange_code.assert_called_once_with(code) self.provider._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
self.handler._fetch_userinfo.assert_not_called() self.provider._fetch_userinfo.assert_not_called()
self.render_error.assert_not_called() self.render_error.assert_not_called()
# Handle mapping errors # Handle mapping errors
with patch.object( with patch.object(
self.handler, self.provider,
"_remote_id_from_userinfo", "_remote_id_from_userinfo",
new=Mock(side_effect=MappingException()), new=Mock(side_effect=MappingException()),
): ):
@ -440,36 +442,36 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error") self.assertRenderedError("mapping_error")
# Handle ID token errors # Handle ID token errors
self.handler._parse_id_token = simple_async_mock(raises=Exception()) self.provider._parse_id_token = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token") self.assertRenderedError("invalid_token")
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
self.handler._exchange_code.reset_mock() self.provider._exchange_code.reset_mock()
self.handler._parse_id_token.reset_mock() self.provider._parse_id_token.reset_mock()
self.handler._fetch_userinfo.reset_mock() self.provider._fetch_userinfo.reset_mock()
# With userinfo fetching # With userinfo fetching
self.handler._scopes = [] # do not ask the "openid" scope self.provider._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, None, expected_user_id, request, client_redirect_url, None,
) )
self.handler._exchange_code.assert_called_once_with(code) self.provider._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called() self.provider._parse_id_token.assert_not_called()
self.handler._fetch_userinfo.assert_called_once_with(token) self.provider._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called() self.render_error.assert_not_called()
# Handle userinfo fetching error # Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error") self.assertRenderedError("fetch_error")
# Handle code exchange failure # Handle code exchange failure
from synapse.handlers.oidc_handler import OidcError from synapse.handlers.oidc_handler import OidcError
self.handler._exchange_code = simple_async_mock( self.provider._exchange_code = simple_async_mock(
raises=OidcError("invalid_request") raises=OidcError("invalid_request")
) )
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
@ -524,7 +526,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
) )
code = "code" code = "code"
ret = self.get_success(self.handler._exchange_code(code)) ret = self.get_success(self.provider._exchange_code(code))
kwargs = self.http_client.request.call_args[1] kwargs = self.http_client.request.call_args[1]
self.assertEqual(ret, token) self.assertEqual(ret, token)
@ -548,7 +550,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
from synapse.handlers.oidc_handler import OidcError from synapse.handlers.oidc_handler import OidcError
exc = self.get_failure(self.handler._exchange_code(code), OidcError) exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "foo") self.assertEqual(exc.value.error, "foo")
self.assertEqual(exc.value.error_description, "bar") self.assertEqual(exc.value.error_description, "bar")
@ -558,7 +560,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=500, phrase=b"Internal Server Error", body=b"Not JSON", code=500, phrase=b"Internal Server Error", body=b"Not JSON",
) )
) )
exc = self.get_failure(self.handler._exchange_code(code), OidcError) exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error") self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body # Internal server error with JSON body
@ -570,14 +572,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
) )
exc = self.get_failure(self.handler._exchange_code(code), OidcError) exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error") self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field # 4xx error without "error" field
self.http_client.request = simple_async_mock( self.http_client.request = simple_async_mock(
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
) )
exc = self.get_failure(self.handler._exchange_code(code), OidcError) exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error") self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field # 2xx error with "error" field
@ -586,7 +588,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=200, phrase=b"OK", body=b'{"error": "some_error"}', code=200, phrase=b"OK", body=b'{"error": "some_error"}',
) )
) )
exc = self.get_failure(self.handler._exchange_code(code), OidcError) exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "some_error") self.assertEqual(exc.value.error, "some_error")
@override_config( @override_config(
@ -612,8 +614,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "foo", "username": "foo",
"phone": "1234567", "phone": "1234567",
} }
self.handler._exchange_code = simple_async_mock(return_value=token) self.provider._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo) self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -979,9 +981,10 @@ async def _make_callback_with_userinfo(
from synapse.handlers.oidc_handler import OidcSessionData from synapse.handlers.oidc_handler import OidcSessionData
handler = hs.get_oidc_handler() handler = hs.get_oidc_handler()
handler._exchange_code = simple_async_mock(return_value={}) provider = handler._provider
handler._parse_id_token = simple_async_mock(return_value=userinfo) provider._exchange_code = simple_async_mock(return_value={})
handler._fetch_userinfo = simple_async_mock(return_value=userinfo) provider._parse_id_token = simple_async_mock(return_value=userinfo)
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
state = "state" state = "state"
session = handler._token_generator.generate_oidc_session_token( session = handler._token_generator.generate_oidc_session_token(