mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-12-11 09:55:46 -05:00
Combine the SSO Redirect Servlets (#9015)
* Implement CasHandler.handle_redirect_request ... to make it match OidcHandler and SamlHandler * Clean up interface for OidcHandler.handle_redirect_request Make it accept `client_redirect_url=None`. * Clean up interface for `SamlHandler.handle_redirect_request` ... bring it into line with CAS and OIDC by making it take a Request parameter, move the magic for `client_redirect_url` for UIA into the handler, and fix the return type to be a `str` rather than a `bytes`. * Define a common protocol for SSO auth provider impls * Give SsoIdentityProvider an ID and register them * Combine the SSO Redirect servlets Now that the SsoHandler knows about the identity providers, we can combine the various *RedirectServlets into a single implementation which delegates to the right IdP. * changelog
This commit is contained in:
parent
31b1905e13
commit
d2c616a413
8 changed files with 174 additions and 113 deletions
|
|
@ -311,48 +311,31 @@ class LoginRestServlet(RestServlet):
|
|||
return result
|
||||
|
||||
|
||||
class BaseSSORedirectServlet(RestServlet):
|
||||
"""Common base class for /login/sso/redirect impls"""
|
||||
|
||||
class SsoRedirectServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
# make sure that the relevant handlers are instantiated, so that they
|
||||
# register themselves with the main SSOHandler.
|
||||
if hs.config.cas_enabled:
|
||||
hs.get_cas_handler()
|
||||
elif hs.config.saml2_enabled:
|
||||
hs.get_saml_handler()
|
||||
elif hs.config.oidc_enabled:
|
||||
hs.get_oidc_handler()
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest):
|
||||
args = request.args
|
||||
if b"redirectUrl" not in args:
|
||||
return 400, "Redirect URL not specified for SSO auth"
|
||||
client_redirect_url = args[b"redirectUrl"][0]
|
||||
sso_url = await self.get_sso_url(request, client_redirect_url)
|
||||
client_redirect_url = parse_string(
|
||||
request, "redirectUrl", required=True, encoding=None
|
||||
)
|
||||
sso_url = await self._sso_handler.handle_redirect_request(
|
||||
request, client_redirect_url
|
||||
)
|
||||
logger.info("Redirecting to %s", sso_url)
|
||||
request.redirect(sso_url)
|
||||
finish_request(request)
|
||||
|
||||
async def get_sso_url(
|
||||
self, request: SynapseRequest, client_redirect_url: bytes
|
||||
) -> bytes:
|
||||
"""Get the URL to redirect to, to perform SSO auth
|
||||
|
||||
Args:
|
||||
request: The client request to redirect.
|
||||
client_redirect_url: the URL that we should redirect the
|
||||
client to when everything is done
|
||||
|
||||
Returns:
|
||||
URL to redirect to
|
||||
"""
|
||||
# to be implemented by subclasses
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CasRedirectServlet(BaseSSORedirectServlet):
|
||||
def __init__(self, hs):
|
||||
self._cas_handler = hs.get_cas_handler()
|
||||
|
||||
async def get_sso_url(
|
||||
self, request: SynapseRequest, client_redirect_url: bytes
|
||||
) -> bytes:
|
||||
return self._cas_handler.get_redirect_url(
|
||||
{"redirectUrl": client_redirect_url}
|
||||
).encode("ascii")
|
||||
|
||||
|
||||
class CasTicketServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
|
||||
|
|
@ -379,40 +362,8 @@ class CasTicketServlet(RestServlet):
|
|||
)
|
||||
|
||||
|
||||
class SAMLRedirectServlet(BaseSSORedirectServlet):
|
||||
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
self._saml_handler = hs.get_saml_handler()
|
||||
|
||||
async def get_sso_url(
|
||||
self, request: SynapseRequest, client_redirect_url: bytes
|
||||
) -> bytes:
|
||||
return self._saml_handler.handle_redirect_request(client_redirect_url)
|
||||
|
||||
|
||||
class OIDCRedirectServlet(BaseSSORedirectServlet):
|
||||
"""Implementation for /login/sso/redirect for the OIDC login flow."""
|
||||
|
||||
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
self._oidc_handler = hs.get_oidc_handler()
|
||||
|
||||
async def get_sso_url(
|
||||
self, request: SynapseRequest, client_redirect_url: bytes
|
||||
) -> bytes:
|
||||
return await self._oidc_handler.handle_redirect_request(
|
||||
request, client_redirect_url
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
LoginRestServlet(hs).register(http_server)
|
||||
SsoRedirectServlet(hs).register(http_server)
|
||||
if hs.config.cas_enabled:
|
||||
CasRedirectServlet(hs).register(http_server)
|
||||
CasTicketServlet(hs).register(http_server)
|
||||
elif hs.config.saml2_enabled:
|
||||
SAMLRedirectServlet(hs).register(http_server)
|
||||
elif hs.config.oidc_enabled:
|
||||
OIDCRedirectServlet(hs).register(http_server)
|
||||
|
|
|
|||
|
|
@ -14,15 +14,20 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.urls import CLIENT_API_PREFIX
|
||||
from synapse.handlers.sso import SsoIdentityProvider
|
||||
from synapse.http.server import respond_with_html
|
||||
from synapse.http.servlet import RestServlet, parse_string
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -35,7 +40,7 @@ class AuthRestServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
|
|
@ -85,31 +90,20 @@ class AuthRestServlet(RestServlet):
|
|||
elif stagetype == LoginType.SSO:
|
||||
# Display a confirmation page which prompts the user to
|
||||
# re-authenticate with their SSO provider.
|
||||
|
||||
if self._cas_enabled:
|
||||
# Generate a request to CAS that redirects back to an endpoint
|
||||
# to verify the successful authentication.
|
||||
sso_redirect_url = self._cas_handler.get_redirect_url(
|
||||
{"session": session},
|
||||
)
|
||||
|
||||
sso_auth_provider = self._cas_handler # type: SsoIdentityProvider
|
||||
elif self._saml_enabled:
|
||||
# Some SAML identity providers (e.g. Google) require a
|
||||
# RelayState parameter on requests. It is not necessary here, so
|
||||
# pass in a dummy redirect URL (which will never get used).
|
||||
client_redirect_url = b"unused"
|
||||
sso_redirect_url = self._saml_handler.handle_redirect_request(
|
||||
client_redirect_url, session
|
||||
)
|
||||
|
||||
sso_auth_provider = self._saml_handler
|
||||
elif self._oidc_enabled:
|
||||
client_redirect_url = b""
|
||||
sso_redirect_url = await self._oidc_handler.handle_redirect_request(
|
||||
request, client_redirect_url, session
|
||||
)
|
||||
|
||||
sso_auth_provider = self._oidc_handler
|
||||
else:
|
||||
raise SynapseError(400, "Homeserver not configured for SSO.")
|
||||
|
||||
sso_redirect_url = await sso_auth_provider.handle_redirect_request(
|
||||
request, None, session
|
||||
)
|
||||
|
||||
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
|
||||
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue