mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-12 15:44:20 -05:00
Split OidcProvider out of OidcHandler (#9107)
The idea here is that we will have an instance of OidcProvider for each configured IdP, with OidcHandler just doing the marshalling of them. For now it's still hardcoded with a single provider.
This commit is contained in:
parent
12702be951
commit
21a296cd5a
1
changelog.d/9107.feature
Normal file
1
changelog.d/9107.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add support for multiple SSO Identity Providers.
|
@ -429,7 +429,6 @@ def setup(config_options):
|
|||||||
oidc = hs.get_oidc_handler()
|
oidc = hs.get_oidc_handler()
|
||||||
# Loading the provider metadata also ensures the provider config is valid.
|
# Loading the provider metadata also ensures the provider config is valid.
|
||||||
await oidc.load_metadata()
|
await oidc.load_metadata()
|
||||||
await oidc.load_jwks()
|
|
||||||
|
|
||||||
await _base.start(hs, config.listeners)
|
await _base.start(hs, config.listeners)
|
||||||
|
|
||||||
|
@ -35,6 +35,7 @@ from typing_extensions import TypedDict
|
|||||||
from twisted.web.client import readBody
|
from twisted.web.client import readBody
|
||||||
|
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
|
from synapse.config.oidc_config import OidcProviderConfig
|
||||||
from synapse.handlers.sso import MappingException, UserAttributes
|
from synapse.handlers.sso import MappingException, UserAttributes
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
@ -70,6 +71,131 @@ JWK = Dict[str, str]
|
|||||||
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
|
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
|
||||||
|
|
||||||
|
|
||||||
|
class OidcHandler:
|
||||||
|
"""Handles requests related to the OpenID Connect login flow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
|
||||||
|
provider_conf = hs.config.oidc.oidc_provider
|
||||||
|
# we should not have been instantiated if there is no configured provider.
|
||||||
|
assert provider_conf is not None
|
||||||
|
|
||||||
|
self._token_generator = OidcSessionTokenGenerator(hs)
|
||||||
|
|
||||||
|
self._provider = OidcProvider(hs, self._token_generator, provider_conf)
|
||||||
|
|
||||||
|
async def load_metadata(self) -> None:
|
||||||
|
"""Validate the config and load the metadata from the remote endpoint.
|
||||||
|
|
||||||
|
Called at startup to ensure we have everything we need.
|
||||||
|
"""
|
||||||
|
await self._provider.load_metadata()
|
||||||
|
await self._provider.load_jwks()
|
||||||
|
|
||||||
|
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
|
||||||
|
"""Handle an incoming request to /_synapse/oidc/callback
|
||||||
|
|
||||||
|
Since we might want to display OIDC-related errors in a user-friendly
|
||||||
|
way, we don't raise SynapseError from here. Instead, we call
|
||||||
|
``self._sso_handler.render_error`` which displays an HTML page for the error.
|
||||||
|
|
||||||
|
Most of the OpenID Connect logic happens here:
|
||||||
|
|
||||||
|
- first, we check if there was any error returned by the provider and
|
||||||
|
display it
|
||||||
|
- then we fetch the session cookie, decode and verify it
|
||||||
|
- the ``state`` query parameter should match with the one stored in the
|
||||||
|
session cookie
|
||||||
|
|
||||||
|
Once we know the session is legit, we then delegate to the OIDC Provider
|
||||||
|
implementation, which will exchange the code with the provider and complete the
|
||||||
|
login/authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: the incoming request from the browser.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The provider might redirect with an error.
|
||||||
|
# In that case, just display it as-is.
|
||||||
|
if b"error" in request.args:
|
||||||
|
# error response from the auth server. see:
|
||||||
|
# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
|
||||||
|
# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
|
||||||
|
error = request.args[b"error"][0].decode()
|
||||||
|
description = request.args.get(b"error_description", [b""])[0].decode()
|
||||||
|
|
||||||
|
# Most of the errors returned by the provider could be due by
|
||||||
|
# either the provider misbehaving or Synapse being misconfigured.
|
||||||
|
# The only exception of that is "access_denied", where the user
|
||||||
|
# probably cancelled the login flow. In other cases, log those errors.
|
||||||
|
if error != "access_denied":
|
||||||
|
logger.error("Error from the OIDC provider: %s %s", error, description)
|
||||||
|
|
||||||
|
self._sso_handler.render_error(request, error, description)
|
||||||
|
return
|
||||||
|
|
||||||
|
# otherwise, it is presumably a successful response. see:
|
||||||
|
# https://tools.ietf.org/html/rfc6749#section-4.1.2
|
||||||
|
|
||||||
|
# Fetch the session cookie
|
||||||
|
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
|
||||||
|
if session is None:
|
||||||
|
logger.info("No session cookie found")
|
||||||
|
self._sso_handler.render_error(
|
||||||
|
request, "missing_session", "No session cookie found"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Remove the cookie. There is a good chance that if the callback failed
|
||||||
|
# once, it will fail next time and the code will already be exchanged.
|
||||||
|
# Removing it early avoids spamming the provider with token requests.
|
||||||
|
request.addCookie(
|
||||||
|
SESSION_COOKIE_NAME,
|
||||||
|
b"",
|
||||||
|
path="/_synapse/oidc",
|
||||||
|
expires="Thu, Jan 01 1970 00:00:00 UTC",
|
||||||
|
httpOnly=True,
|
||||||
|
sameSite="lax",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for the state query parameter
|
||||||
|
if b"state" not in request.args:
|
||||||
|
logger.info("State parameter is missing")
|
||||||
|
self._sso_handler.render_error(
|
||||||
|
request, "invalid_request", "State parameter is missing"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
state = request.args[b"state"][0].decode()
|
||||||
|
|
||||||
|
# Deserialize the session token and verify it.
|
||||||
|
try:
|
||||||
|
session_data = self._token_generator.verify_oidc_session_token(
|
||||||
|
session, state
|
||||||
|
)
|
||||||
|
except MacaroonDeserializationException as e:
|
||||||
|
logger.exception("Invalid session")
|
||||||
|
self._sso_handler.render_error(request, "invalid_session", str(e))
|
||||||
|
return
|
||||||
|
except MacaroonInvalidSignatureException as e:
|
||||||
|
logger.exception("Could not verify session")
|
||||||
|
self._sso_handler.render_error(request, "mismatching_session", str(e))
|
||||||
|
return
|
||||||
|
|
||||||
|
if b"code" not in request.args:
|
||||||
|
logger.info("Code parameter is missing")
|
||||||
|
self._sso_handler.render_error(
|
||||||
|
request, "invalid_request", "Code parameter is missing"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
code = request.args[b"code"][0].decode()
|
||||||
|
|
||||||
|
await self._provider.handle_oidc_callback(request, session_data, code)
|
||||||
|
|
||||||
|
|
||||||
class OidcError(Exception):
|
class OidcError(Exception):
|
||||||
"""Used to catch errors when calling the token_endpoint
|
"""Used to catch errors when calling the token_endpoint
|
||||||
"""
|
"""
|
||||||
@ -84,21 +210,25 @@ class OidcError(Exception):
|
|||||||
return self.error
|
return self.error
|
||||||
|
|
||||||
|
|
||||||
class OidcHandler:
|
class OidcProvider:
|
||||||
"""Handles requests related to the OpenID Connect login flow.
|
"""Wraps the config for a single OIDC IdentityProvider
|
||||||
|
|
||||||
|
Provides methods for handling redirect requests and callbacks via that particular
|
||||||
|
IdP.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
token_generator: "OidcSessionTokenGenerator",
|
||||||
|
provider: OidcProviderConfig,
|
||||||
|
):
|
||||||
self._store = hs.get_datastore()
|
self._store = hs.get_datastore()
|
||||||
|
|
||||||
self._token_generator = OidcSessionTokenGenerator(hs)
|
self._token_generator = token_generator
|
||||||
|
|
||||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||||
|
|
||||||
provider = hs.config.oidc.oidc_provider
|
|
||||||
# we should not have been instantiated if there is no configured provider.
|
|
||||||
assert provider is not None
|
|
||||||
|
|
||||||
self._scopes = provider.scopes
|
self._scopes = provider.scopes
|
||||||
self._user_profile_method = provider.user_profile_method
|
self._user_profile_method = provider.user_profile_method
|
||||||
self._client_auth = ClientAuth(
|
self._client_auth = ClientAuth(
|
||||||
@ -552,22 +682,16 @@ class OidcHandler:
|
|||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
|
async def handle_oidc_callback(
|
||||||
|
self, request: SynapseRequest, session_data: "OidcSessionData", code: str
|
||||||
|
) -> None:
|
||||||
"""Handle an incoming request to /_synapse/oidc/callback
|
"""Handle an incoming request to /_synapse/oidc/callback
|
||||||
|
|
||||||
Since we might want to display OIDC-related errors in a user-friendly
|
By this time we have already validated the session on the synapse side, and
|
||||||
way, we don't raise SynapseError from here. Instead, we call
|
now need to do the provider-specific operations. This includes:
|
||||||
``self._sso_handler.render_error`` which displays an HTML page for the error.
|
|
||||||
|
|
||||||
Most of the OpenID Connect logic happens here:
|
- exchange the code with the provider using the ``token_endpoint`` (see
|
||||||
|
``_exchange_code``)
|
||||||
- first, we check if there was any error returned by the provider and
|
|
||||||
display it
|
|
||||||
- then we fetch the session cookie, decode and verify it
|
|
||||||
- the ``state`` query parameter should match with the one stored in the
|
|
||||||
session cookie
|
|
||||||
- once we known this session is legit, exchange the code with the
|
|
||||||
provider using the ``token_endpoint`` (see ``_exchange_code``)
|
|
||||||
- once we have the token, use it to either extract the UserInfo from
|
- once we have the token, use it to either extract the UserInfo from
|
||||||
the ``id_token`` (``_parse_id_token``), or use the ``access_token``
|
the ``id_token`` (``_parse_id_token``), or use the ``access_token``
|
||||||
to fetch UserInfo from the ``userinfo_endpoint``
|
to fetch UserInfo from the ``userinfo_endpoint``
|
||||||
@ -577,86 +701,12 @@ class OidcHandler:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: the incoming request from the browser.
|
request: the incoming request from the browser.
|
||||||
|
session_data: the session data, extracted from our cookie
|
||||||
|
code: The authorization code we got from the callback.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# The provider might redirect with an error.
|
|
||||||
# In that case, just display it as-is.
|
|
||||||
if b"error" in request.args:
|
|
||||||
# error response from the auth server. see:
|
|
||||||
# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
|
|
||||||
# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
|
|
||||||
error = request.args[b"error"][0].decode()
|
|
||||||
description = request.args.get(b"error_description", [b""])[0].decode()
|
|
||||||
|
|
||||||
# Most of the errors returned by the provider could be due by
|
|
||||||
# either the provider misbehaving or Synapse being misconfigured.
|
|
||||||
# The only exception of that is "access_denied", where the user
|
|
||||||
# probably cancelled the login flow. In other cases, log those errors.
|
|
||||||
if error != "access_denied":
|
|
||||||
logger.error("Error from the OIDC provider: %s %s", error, description)
|
|
||||||
|
|
||||||
self._sso_handler.render_error(request, error, description)
|
|
||||||
return
|
|
||||||
|
|
||||||
# otherwise, it is presumably a successful response. see:
|
|
||||||
# https://tools.ietf.org/html/rfc6749#section-4.1.2
|
|
||||||
|
|
||||||
# Fetch the session cookie
|
|
||||||
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
|
|
||||||
if session is None:
|
|
||||||
logger.info("No session cookie found")
|
|
||||||
self._sso_handler.render_error(
|
|
||||||
request, "missing_session", "No session cookie found"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Remove the cookie. There is a good chance that if the callback failed
|
|
||||||
# once, it will fail next time and the code will already be exchanged.
|
|
||||||
# Removing it early avoids spamming the provider with token requests.
|
|
||||||
request.addCookie(
|
|
||||||
SESSION_COOKIE_NAME,
|
|
||||||
b"",
|
|
||||||
path="/_synapse/oidc",
|
|
||||||
expires="Thu, Jan 01 1970 00:00:00 UTC",
|
|
||||||
httpOnly=True,
|
|
||||||
sameSite="lax",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check for the state query parameter
|
|
||||||
if b"state" not in request.args:
|
|
||||||
logger.info("State parameter is missing")
|
|
||||||
self._sso_handler.render_error(
|
|
||||||
request, "invalid_request", "State parameter is missing"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
state = request.args[b"state"][0].decode()
|
|
||||||
|
|
||||||
# Deserialize the session token and verify it.
|
|
||||||
try:
|
|
||||||
session_data = self._token_generator.verify_oidc_session_token(
|
|
||||||
session, state
|
|
||||||
)
|
|
||||||
except MacaroonDeserializationException as e:
|
|
||||||
logger.exception("Invalid session")
|
|
||||||
self._sso_handler.render_error(request, "invalid_session", str(e))
|
|
||||||
return
|
|
||||||
except MacaroonInvalidSignatureException as e:
|
|
||||||
logger.exception("Could not verify session")
|
|
||||||
self._sso_handler.render_error(request, "mismatching_session", str(e))
|
|
||||||
return
|
|
||||||
|
|
||||||
# Exchange the code with the provider
|
# Exchange the code with the provider
|
||||||
if b"code" not in request.args:
|
|
||||||
logger.info("Code parameter is missing")
|
|
||||||
self._sso_handler.render_error(
|
|
||||||
request, "invalid_request", "Code parameter is missing"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug("Exchanging code")
|
|
||||||
code = request.args[b"code"][0].decode()
|
|
||||||
try:
|
try:
|
||||||
|
logger.debug("Exchanging code")
|
||||||
token = await self._exchange_code(code)
|
token = await self._exchange_code(code)
|
||||||
except OidcError as e:
|
except OidcError as e:
|
||||||
logger.exception("Could not exchange code")
|
logger.exception("Could not exchange code")
|
||||||
|
@ -151,6 +151,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
|
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
|
||||||
|
|
||||||
self.handler = hs.get_oidc_handler()
|
self.handler = hs.get_oidc_handler()
|
||||||
|
self.provider = self.handler._provider
|
||||||
sso_handler = hs.get_sso_handler()
|
sso_handler = hs.get_sso_handler()
|
||||||
# Mock the render error method.
|
# Mock the render error method.
|
||||||
self.render_error = Mock(return_value=None)
|
self.render_error = Mock(return_value=None)
|
||||||
@ -162,9 +163,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
return hs
|
return hs
|
||||||
|
|
||||||
def metadata_edit(self, values):
|
def metadata_edit(self, values):
|
||||||
return patch.dict(self.handler._provider_metadata, values)
|
return patch.dict(self.provider._provider_metadata, values)
|
||||||
|
|
||||||
def assertRenderedError(self, error, error_description=None):
|
def assertRenderedError(self, error, error_description=None):
|
||||||
|
self.render_error.assert_called_once()
|
||||||
args = self.render_error.call_args[0]
|
args = self.render_error.call_args[0]
|
||||||
self.assertEqual(args[1], error)
|
self.assertEqual(args[1], error)
|
||||||
if error_description is not None:
|
if error_description is not None:
|
||||||
@ -175,15 +177,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
def test_config(self):
|
def test_config(self):
|
||||||
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
||||||
self.assertEqual(self.handler._callback_url, CALLBACK_URL)
|
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
||||||
self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
|
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
||||||
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
|
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
||||||
|
|
||||||
@override_config({"oidc_config": {"discover": True}})
|
@override_config({"oidc_config": {"discover": True}})
|
||||||
def test_discovery(self):
|
def test_discovery(self):
|
||||||
"""The handler should discover the endpoints from OIDC discovery document."""
|
"""The handler should discover the endpoints from OIDC discovery document."""
|
||||||
# This would throw if some metadata were invalid
|
# This would throw if some metadata were invalid
|
||||||
metadata = self.get_success(self.handler.load_metadata())
|
metadata = self.get_success(self.provider.load_metadata())
|
||||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||||
|
|
||||||
self.assertEqual(metadata.issuer, ISSUER)
|
self.assertEqual(metadata.issuer, ISSUER)
|
||||||
@ -195,47 +197,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# subsequent calls should be cached
|
# subsequent calls should be cached
|
||||||
self.http_client.reset_mock()
|
self.http_client.reset_mock()
|
||||||
self.get_success(self.handler.load_metadata())
|
self.get_success(self.provider.load_metadata())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": COMMON_CONFIG})
|
@override_config({"oidc_config": COMMON_CONFIG})
|
||||||
def test_no_discovery(self):
|
def test_no_discovery(self):
|
||||||
"""When discovery is disabled, it should not try to load from discovery document."""
|
"""When discovery is disabled, it should not try to load from discovery document."""
|
||||||
self.get_success(self.handler.load_metadata())
|
self.get_success(self.provider.load_metadata())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": COMMON_CONFIG})
|
@override_config({"oidc_config": COMMON_CONFIG})
|
||||||
def test_load_jwks(self):
|
def test_load_jwks(self):
|
||||||
"""JWKS loading is done once (then cached) if used."""
|
"""JWKS loading is done once (then cached) if used."""
|
||||||
jwks = self.get_success(self.handler.load_jwks())
|
jwks = self.get_success(self.provider.load_jwks())
|
||||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||||
self.assertEqual(jwks, {"keys": []})
|
self.assertEqual(jwks, {"keys": []})
|
||||||
|
|
||||||
# subsequent calls should be cached…
|
# subsequent calls should be cached…
|
||||||
self.http_client.reset_mock()
|
self.http_client.reset_mock()
|
||||||
self.get_success(self.handler.load_jwks())
|
self.get_success(self.provider.load_jwks())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
# …unless forced
|
# …unless forced
|
||||||
self.http_client.reset_mock()
|
self.http_client.reset_mock()
|
||||||
self.get_success(self.handler.load_jwks(force=True))
|
self.get_success(self.provider.load_jwks(force=True))
|
||||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||||
|
|
||||||
# Throw if the JWKS uri is missing
|
# Throw if the JWKS uri is missing
|
||||||
with self.metadata_edit({"jwks_uri": None}):
|
with self.metadata_edit({"jwks_uri": None}):
|
||||||
self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
|
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
||||||
|
|
||||||
# Return empty key set if JWKS are not used
|
# Return empty key set if JWKS are not used
|
||||||
self.handler._scopes = [] # not asking the openid scope
|
self.provider._scopes = [] # not asking the openid scope
|
||||||
self.http_client.get_json.reset_mock()
|
self.http_client.get_json.reset_mock()
|
||||||
jwks = self.get_success(self.handler.load_jwks(force=True))
|
jwks = self.get_success(self.provider.load_jwks(force=True))
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
self.assertEqual(jwks, {"keys": []})
|
self.assertEqual(jwks, {"keys": []})
|
||||||
|
|
||||||
@override_config({"oidc_config": COMMON_CONFIG})
|
@override_config({"oidc_config": COMMON_CONFIG})
|
||||||
def test_validate_config(self):
|
def test_validate_config(self):
|
||||||
"""Provider metadatas are extensively validated."""
|
"""Provider metadatas are extensively validated."""
|
||||||
h = self.handler
|
h = self.provider
|
||||||
|
|
||||||
# Default test config does not throw
|
# Default test config does not throw
|
||||||
h._validate_metadata()
|
h._validate_metadata()
|
||||||
@ -314,13 +316,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
"""Provider metadata validation can be disabled by config."""
|
"""Provider metadata validation can be disabled by config."""
|
||||||
with self.metadata_edit({"issuer": "http://insecure"}):
|
with self.metadata_edit({"issuer": "http://insecure"}):
|
||||||
# This should not throw
|
# This should not throw
|
||||||
self.handler._validate_metadata()
|
self.provider._validate_metadata()
|
||||||
|
|
||||||
def test_redirect_request(self):
|
def test_redirect_request(self):
|
||||||
"""The redirect request has the right arguments & generates a valid session cookie."""
|
"""The redirect request has the right arguments & generates a valid session cookie."""
|
||||||
req = Mock(spec=["addCookie"])
|
req = Mock(spec=["addCookie"])
|
||||||
url = self.get_success(
|
url = self.get_success(
|
||||||
self.handler.handle_redirect_request(req, b"http://client/redirect")
|
self.provider.handle_redirect_request(req, b"http://client/redirect")
|
||||||
)
|
)
|
||||||
url = urlparse(url)
|
url = urlparse(url)
|
||||||
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
|
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
|
||||||
@ -388,7 +390,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# ensure that we are correctly testing the fallback when "get_extra_attributes"
|
# ensure that we are correctly testing the fallback when "get_extra_attributes"
|
||||||
# is not implemented.
|
# is not implemented.
|
||||||
mapping_provider = self.handler._user_mapping_provider
|
mapping_provider = self.provider._user_mapping_provider
|
||||||
with self.assertRaises(AttributeError):
|
with self.assertRaises(AttributeError):
|
||||||
_ = mapping_provider.get_extra_attributes
|
_ = mapping_provider.get_extra_attributes
|
||||||
|
|
||||||
@ -403,9 +405,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
"username": username,
|
"username": username,
|
||||||
}
|
}
|
||||||
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
||||||
self.handler._exchange_code = simple_async_mock(return_value=token)
|
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||||
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
|
||||||
@ -425,14 +427,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
expected_user_id, request, client_redirect_url, None,
|
expected_user_id, request, client_redirect_url, None,
|
||||||
)
|
)
|
||||||
self.handler._exchange_code.assert_called_once_with(code)
|
self.provider._exchange_code.assert_called_once_with(code)
|
||||||
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||||
self.handler._fetch_userinfo.assert_not_called()
|
self.provider._fetch_userinfo.assert_not_called()
|
||||||
self.render_error.assert_not_called()
|
self.render_error.assert_not_called()
|
||||||
|
|
||||||
# Handle mapping errors
|
# Handle mapping errors
|
||||||
with patch.object(
|
with patch.object(
|
||||||
self.handler,
|
self.provider,
|
||||||
"_remote_id_from_userinfo",
|
"_remote_id_from_userinfo",
|
||||||
new=Mock(side_effect=MappingException()),
|
new=Mock(side_effect=MappingException()),
|
||||||
):
|
):
|
||||||
@ -440,36 +442,36 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
self.assertRenderedError("mapping_error")
|
self.assertRenderedError("mapping_error")
|
||||||
|
|
||||||
# Handle ID token errors
|
# Handle ID token errors
|
||||||
self.handler._parse_id_token = simple_async_mock(raises=Exception())
|
self.provider._parse_id_token = simple_async_mock(raises=Exception())
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_token")
|
self.assertRenderedError("invalid_token")
|
||||||
|
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
self.handler._exchange_code.reset_mock()
|
self.provider._exchange_code.reset_mock()
|
||||||
self.handler._parse_id_token.reset_mock()
|
self.provider._parse_id_token.reset_mock()
|
||||||
self.handler._fetch_userinfo.reset_mock()
|
self.provider._fetch_userinfo.reset_mock()
|
||||||
|
|
||||||
# With userinfo fetching
|
# With userinfo fetching
|
||||||
self.handler._scopes = [] # do not ask the "openid" scope
|
self.provider._scopes = [] # do not ask the "openid" scope
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
expected_user_id, request, client_redirect_url, None,
|
expected_user_id, request, client_redirect_url, None,
|
||||||
)
|
)
|
||||||
self.handler._exchange_code.assert_called_once_with(code)
|
self.provider._exchange_code.assert_called_once_with(code)
|
||||||
self.handler._parse_id_token.assert_not_called()
|
self.provider._parse_id_token.assert_not_called()
|
||||||
self.handler._fetch_userinfo.assert_called_once_with(token)
|
self.provider._fetch_userinfo.assert_called_once_with(token)
|
||||||
self.render_error.assert_not_called()
|
self.render_error.assert_not_called()
|
||||||
|
|
||||||
# Handle userinfo fetching error
|
# Handle userinfo fetching error
|
||||||
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
|
self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("fetch_error")
|
self.assertRenderedError("fetch_error")
|
||||||
|
|
||||||
# Handle code exchange failure
|
# Handle code exchange failure
|
||||||
from synapse.handlers.oidc_handler import OidcError
|
from synapse.handlers.oidc_handler import OidcError
|
||||||
|
|
||||||
self.handler._exchange_code = simple_async_mock(
|
self.provider._exchange_code = simple_async_mock(
|
||||||
raises=OidcError("invalid_request")
|
raises=OidcError("invalid_request")
|
||||||
)
|
)
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
@ -524,7 +526,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
|
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
|
||||||
)
|
)
|
||||||
code = "code"
|
code = "code"
|
||||||
ret = self.get_success(self.handler._exchange_code(code))
|
ret = self.get_success(self.provider._exchange_code(code))
|
||||||
kwargs = self.http_client.request.call_args[1]
|
kwargs = self.http_client.request.call_args[1]
|
||||||
|
|
||||||
self.assertEqual(ret, token)
|
self.assertEqual(ret, token)
|
||||||
@ -548,7 +550,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
from synapse.handlers.oidc_handler import OidcError
|
from synapse.handlers.oidc_handler import OidcError
|
||||||
|
|
||||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
||||||
self.assertEqual(exc.value.error, "foo")
|
self.assertEqual(exc.value.error, "foo")
|
||||||
self.assertEqual(exc.value.error_description, "bar")
|
self.assertEqual(exc.value.error_description, "bar")
|
||||||
|
|
||||||
@ -558,7 +560,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
|
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
||||||
self.assertEqual(exc.value.error, "server_error")
|
self.assertEqual(exc.value.error, "server_error")
|
||||||
|
|
||||||
# Internal server error with JSON body
|
# Internal server error with JSON body
|
||||||
@ -570,14 +572,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
||||||
self.assertEqual(exc.value.error, "internal_server_error")
|
self.assertEqual(exc.value.error, "internal_server_error")
|
||||||
|
|
||||||
# 4xx error without "error" field
|
# 4xx error without "error" field
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = simple_async_mock(
|
||||||
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
|
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
|
||||||
)
|
)
|
||||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
||||||
self.assertEqual(exc.value.error, "server_error")
|
self.assertEqual(exc.value.error, "server_error")
|
||||||
|
|
||||||
# 2xx error with "error" field
|
# 2xx error with "error" field
|
||||||
@ -586,7 +588,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
|
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
||||||
self.assertEqual(exc.value.error, "some_error")
|
self.assertEqual(exc.value.error, "some_error")
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
@ -612,8 +614,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
"username": "foo",
|
"username": "foo",
|
||||||
"phone": "1234567",
|
"phone": "1234567",
|
||||||
}
|
}
|
||||||
self.handler._exchange_code = simple_async_mock(return_value=token)
|
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
|
||||||
@ -979,9 +981,10 @@ async def _make_callback_with_userinfo(
|
|||||||
from synapse.handlers.oidc_handler import OidcSessionData
|
from synapse.handlers.oidc_handler import OidcSessionData
|
||||||
|
|
||||||
handler = hs.get_oidc_handler()
|
handler = hs.get_oidc_handler()
|
||||||
handler._exchange_code = simple_async_mock(return_value={})
|
provider = handler._provider
|
||||||
handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
provider._exchange_code = simple_async_mock(return_value={})
|
||||||
handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||||
|
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||||
|
|
||||||
state = "state"
|
state = "state"
|
||||||
session = handler._token_generator.generate_oidc_session_token(
|
session = handler._token_generator.generate_oidc_session_token(
|
||||||
|
Loading…
Reference in New Issue
Block a user