Combine the SSO Redirect Servlets (#9015)

* Implement CasHandler.handle_redirect_request

... to make it match OidcHandler and SamlHandler

* Clean up interface for OidcHandler.handle_redirect_request

Make it accept `client_redirect_url=None`.

* Clean up interface for `SamlHandler.handle_redirect_request`

... bring it into line with CAS and OIDC by making it take a Request parameter,
move the magic for `client_redirect_url` for UIA into the handler, and fix the
return type to be a `str` rather than a `bytes`.

* Define a common protocol for SSO auth provider impls

* Give SsoIdentityProvider an ID and register them

* Combine the SSO Redirect servlets

Now that the SsoHandler knows about the identity providers, we can combine the
various *RedirectServlets into a single implementation which delegates to the
right IdP.

* changelog
This commit is contained in:
Richard van der Hoff 2021-01-04 18:13:49 +00:00 committed by GitHub
parent 31b1905e13
commit d2c616a413
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 174 additions and 113 deletions

1
changelog.d/9015.feature Normal file
View File

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

View File

@ -75,10 +75,12 @@ class CasHandler:
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
# identifier for the external_ids table # identifier for the external_ids table
self._auth_provider_id = "cas" self.idp_id = "cas"
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
def _build_service_param(self, args: Dict[str, str]) -> str: def _build_service_param(self, args: Dict[str, str]) -> str:
""" """
Generates a value to use as the "service" parameter when redirecting or Generates a value to use as the "service" parameter when redirecting or
@ -105,7 +107,7 @@ class CasHandler:
Args: Args:
ticket: The CAS ticket from the client. ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL. service_args: Additional arguments to include in the service URL.
Should be the same as those passed to `get_redirect_url`. Should be the same as those passed to `handle_redirect_request`.
Raises: Raises:
CasError: If there's an error parsing the CAS response. CasError: If there's an error parsing the CAS response.
@ -184,16 +186,31 @@ class CasHandler:
return CasResponse(user, attributes) return CasResponse(user, attributes)
def get_redirect_url(self, service_args: Dict[str, str]) -> str: async def handle_redirect_request(
""" self,
Generates a URL for the CAS server where the client should be redirected. request: SynapseRequest,
client_redirect_url: Optional[bytes],
ui_auth_session_id: Optional[str] = None,
) -> str:
"""Generates a URL for the CAS server where the client should be redirected.
Args: Args:
service_args: Additional arguments to include in the final redirect URL. request: the incoming HTTP request
client_redirect_url: the URL that we should redirect the
client to after login (or None for UI Auth).
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
Returns: Returns:
The URL to redirect the client to. URL to redirect to
""" """
if ui_auth_session_id:
service_args = {"session": ui_auth_session_id}
else:
assert client_redirect_url
service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
args = urllib.parse.urlencode( args = urllib.parse.urlencode(
{"service": self._build_service_param(service_args)} {"service": self._build_service_param(service_args)}
) )
@ -275,7 +292,7 @@ class CasHandler:
# first check if we're doing a UIA # first check if we're doing a UIA
if session: if session:
return await self._sso_handler.complete_sso_ui_auth_request( return await self._sso_handler.complete_sso_ui_auth_request(
self._auth_provider_id, cas_response.username, session, request, self.idp_id, cas_response.username, session, request,
) )
# otherwise, we're handling a login request. # otherwise, we're handling a login request.
@ -375,7 +392,7 @@ class CasHandler:
return None return None
await self._sso_handler.complete_sso_login_request( await self._sso_handler.complete_sso_login_request(
self._auth_provider_id, self.idp_id,
cas_response.username, cas_response.username,
request, request,
client_redirect_url, client_redirect_url,

View File

@ -119,10 +119,12 @@ class OidcHandler(BaseHandler):
self._macaroon_secret_key = hs.config.macaroon_secret_key self._macaroon_secret_key = hs.config.macaroon_secret_key
# identifier for the external_ids table # identifier for the external_ids table
self._auth_provider_id = "oidc" self.idp_id = "oidc"
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
def _validate_metadata(self): def _validate_metadata(self):
"""Verifies the provider metadata. """Verifies the provider metadata.
@ -475,7 +477,7 @@ class OidcHandler(BaseHandler):
async def handle_redirect_request( async def handle_redirect_request(
self, self,
request: SynapseRequest, request: SynapseRequest,
client_redirect_url: bytes, client_redirect_url: Optional[bytes],
ui_auth_session_id: Optional[str] = None, ui_auth_session_id: Optional[str] = None,
) -> str: ) -> str:
"""Handle an incoming request to /login/sso/redirect """Handle an incoming request to /login/sso/redirect
@ -499,7 +501,7 @@ class OidcHandler(BaseHandler):
request: the incoming request from the browser. request: the incoming request from the browser.
We'll respond to it with a redirect and a cookie. We'll respond to it with a redirect and a cookie.
client_redirect_url: the URL that we should redirect the client to client_redirect_url: the URL that we should redirect the client to
when everything is done when everything is done (or None for UI Auth)
ui_auth_session_id: The session ID of the ongoing UI Auth (or ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login). None if this is a login).
@ -511,6 +513,9 @@ class OidcHandler(BaseHandler):
state = generate_token() state = generate_token()
nonce = generate_token() nonce = generate_token()
if not client_redirect_url:
client_redirect_url = b""
cookie = self._generate_oidc_session_token( cookie = self._generate_oidc_session_token(
state=state, state=state,
nonce=nonce, nonce=nonce,
@ -682,7 +687,7 @@ class OidcHandler(BaseHandler):
return return
return await self._sso_handler.complete_sso_ui_auth_request( return await self._sso_handler.complete_sso_ui_auth_request(
self._auth_provider_id, remote_user_id, ui_auth_session_id, request self.idp_id, remote_user_id, ui_auth_session_id, request
) )
# otherwise, it's a login # otherwise, it's a login
@ -923,7 +928,7 @@ class OidcHandler(BaseHandler):
extra_attributes = await get_extra_attributes(userinfo, token) extra_attributes = await get_extra_attributes(userinfo, token)
await self._sso_handler.complete_sso_login_request( await self._sso_handler.complete_sso_login_request(
self._auth_provider_id, self.idp_id,
remote_user_id, remote_user_id,
request, request,
client_redirect_url, client_redirect_url,

View File

@ -73,27 +73,38 @@ class SamlHandler(BaseHandler):
) )
# identifier for the external_ids table # identifier for the external_ids table
self._auth_provider_id = "saml" self.idp_id = "saml"
# a map from saml session id to Saml2SessionData object # a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
def handle_redirect_request( async def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None self,
) -> bytes: request: SynapseRequest,
client_redirect_url: Optional[bytes],
ui_auth_session_id: Optional[str] = None,
) -> str:
"""Handle an incoming request to /login/sso/redirect """Handle an incoming request to /login/sso/redirect
Args: Args:
request: the incoming HTTP request
client_redirect_url: the URL that we should redirect the client_redirect_url: the URL that we should redirect the
client to when everything is done client to after login (or None for UI Auth).
ui_auth_session_id: The session ID of the ongoing UI Auth (or ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login). None if this is a login).
Returns: Returns:
URL to redirect to URL to redirect to
""" """
if not client_redirect_url:
# Some SAML identity providers (e.g. Google) require a
# RelayState parameter on requests, so pass in a dummy redirect URL
# (which will never get used).
client_redirect_url = b"unused"
reqid, info = self._saml_client.prepare_for_authenticate( reqid, info = self._saml_client.prepare_for_authenticate(
entityid=self._saml_idp_entityid, relay_state=client_redirect_url entityid=self._saml_idp_entityid, relay_state=client_redirect_url
) )
@ -210,7 +221,7 @@ class SamlHandler(BaseHandler):
return return
return await self._sso_handler.complete_sso_ui_auth_request( return await self._sso_handler.complete_sso_ui_auth_request(
self._auth_provider_id, self.idp_id,
remote_user_id, remote_user_id,
current_session.ui_auth_session_id, current_session.ui_auth_session_id,
request, request,
@ -306,7 +317,7 @@ class SamlHandler(BaseHandler):
return None return None
await self._sso_handler.complete_sso_login_request( await self._sso_handler.complete_sso_login_request(
self._auth_provider_id, self.idp_id,
remote_user_id, remote_user_id,
request, request,
client_redirect_url, client_redirect_url,

View File

@ -12,15 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
import attr import attr
from typing_extensions import NoReturn from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request from twisted.web.http import Request
from synapse.api.errors import RedirectException, SynapseError from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.http.server import respond_with_html from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
@ -40,6 +41,53 @@ class MappingException(Exception):
""" """
class SsoIdentityProvider(Protocol):
"""Abstract base class to be implemented by SSO Identity Providers
An Identity Provider, or IdP, is an external HTTP service which authenticates a user
to say whether they should be allowed to log in, or perform a given action.
Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
and CAS.
The main entry point is `handle_redirect_request`, which should return a URI to
redirect the user's browser to the IdP's authentication page.
Each IdP should be registered with the SsoHandler via
`hs.get_sso_handler().register_identity_provider()`, so that requests to
`/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
"""
@property
@abc.abstractmethod
def idp_id(self) -> str:
"""A unique identifier for this SSO provider
Eg, "saml", "cas", "github"
"""
@abc.abstractmethod
async def handle_redirect_request(
self,
request: SynapseRequest,
client_redirect_url: Optional[bytes],
ui_auth_session_id: Optional[str] = None,
) -> str:
"""Handle an incoming request to /login/sso/redirect
Args:
request: the incoming HTTP request
client_redirect_url: the URL that we should redirect the
client to after login (or None for UI Auth).
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
Returns:
URL to redirect to
"""
raise NotImplementedError()
@attr.s @attr.s
class UserAttributes: class UserAttributes:
# the localpart of the mxid that the mapper has assigned to the user. # the localpart of the mxid that the mapper has assigned to the user.
@ -100,6 +148,14 @@ class SsoHandler:
# a map from session id to session data # a map from session id to session data
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
# map from idp_id to SsoIdentityProvider
self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
def register_identity_provider(self, p: SsoIdentityProvider):
p_id = p.idp_id
assert p_id not in self._identity_providers
self._identity_providers[p_id] = p
def render_error( def render_error(
self, self,
request: Request, request: Request,
@ -124,6 +180,32 @@ class SsoHandler:
) )
respond_with_html(request, code, html) respond_with_html(request, code, html)
async def handle_redirect_request(
self, request: SynapseRequest, client_redirect_url: bytes,
) -> str:
"""Handle a request to /login/sso/redirect
Args:
request: incoming HTTP request
client_redirect_url: the URL that we should redirect the
client to after login.
Returns:
the URI to redirect to
"""
if not self._identity_providers:
raise SynapseError(
400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
)
# if we only have one auth provider, redirect to it directly
if len(self._identity_providers) == 1:
ap = next(iter(self._identity_providers.values()))
return await ap.handle_redirect_request(request, client_redirect_url)
# otherwise, we have a configuration error
raise Exception("Multiple SSO identity providers have been configured!")
async def get_sso_user_by_remote_user_id( async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str self, auth_provider_id: str, remote_user_id: str
) -> Optional[str]: ) -> Optional[str]:

View File

@ -311,48 +311,31 @@ class LoginRestServlet(RestServlet):
return result return result
class BaseSSORedirectServlet(RestServlet): class SsoRedirectServlet(RestServlet):
"""Common base class for /login/sso/redirect impls"""
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
# register themselves with the main SSOHandler.
if hs.config.cas_enabled:
hs.get_cas_handler()
elif hs.config.saml2_enabled:
hs.get_saml_handler()
elif hs.config.oidc_enabled:
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()
async def on_GET(self, request: SynapseRequest): async def on_GET(self, request: SynapseRequest):
args = request.args client_redirect_url = parse_string(
if b"redirectUrl" not in args: request, "redirectUrl", required=True, encoding=None
return 400, "Redirect URL not specified for SSO auth" )
client_redirect_url = args[b"redirectUrl"][0] sso_url = await self._sso_handler.handle_redirect_request(
sso_url = await self.get_sso_url(request, client_redirect_url) request, client_redirect_url
)
logger.info("Redirecting to %s", sso_url)
request.redirect(sso_url) request.redirect(sso_url)
finish_request(request) finish_request(request)
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
request: The client request to redirect.
client_redirect_url: the URL that we should redirect the
client to when everything is done
Returns:
URL to redirect to
"""
# to be implemented by subclasses
raise NotImplementedError()
class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._cas_handler = hs.get_cas_handler()
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
return self._cas_handler.get_redirect_url(
{"redirectUrl": client_redirect_url}
).encode("ascii")
class CasTicketServlet(RestServlet): class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True) PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@ -379,40 +362,8 @@ class CasTicketServlet(RestServlet):
) )
class SAMLRedirectServlet(BaseSSORedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
class OIDCRedirectServlet(BaseSSORedirectServlet):
"""Implementation for /login/sso/redirect for the OIDC login flow."""
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
def __init__(self, hs):
self._oidc_handler = hs.get_oidc_handler()
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
return await self._oidc_handler.handle_redirect_request(
request, client_redirect_url
)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server) LoginRestServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled: if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server)
elif hs.config.saml2_enabled:
SAMLRedirectServlet(hs).register(http_server)
elif hs.config.oidc_enabled:
OIDCRedirectServlet(hs).register(http_server)

View File

@ -14,15 +14,20 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX from synapse.api.urls import CLIENT_API_PREFIX
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http.server import respond_with_html from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet, parse_string from synapse.http.servlet import RestServlet, parse_string
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,7 +40,7 @@ class AuthRestServlet(RestServlet):
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -85,31 +90,20 @@ class AuthRestServlet(RestServlet):
elif stagetype == LoginType.SSO: elif stagetype == LoginType.SSO:
# Display a confirmation page which prompts the user to # Display a confirmation page which prompts the user to
# re-authenticate with their SSO provider. # re-authenticate with their SSO provider.
if self._cas_enabled: if self._cas_enabled:
# Generate a request to CAS that redirects back to an endpoint sso_auth_provider = self._cas_handler # type: SsoIdentityProvider
# to verify the successful authentication.
sso_redirect_url = self._cas_handler.get_redirect_url(
{"session": session},
)
elif self._saml_enabled: elif self._saml_enabled:
# Some SAML identity providers (e.g. Google) require a sso_auth_provider = self._saml_handler
# RelayState parameter on requests. It is not necessary here, so
# pass in a dummy redirect URL (which will never get used).
client_redirect_url = b"unused"
sso_redirect_url = self._saml_handler.handle_redirect_request(
client_redirect_url, session
)
elif self._oidc_enabled: elif self._oidc_enabled:
client_redirect_url = b"" sso_auth_provider = self._oidc_handler
sso_redirect_url = await self._oidc_handler.handle_redirect_request(
request, client_redirect_url, session
)
else: else:
raise SynapseError(400, "Homeserver not configured for SSO.") raise SynapseError(400, "Homeserver not configured for SSO.")
sso_redirect_url = await sso_auth_provider.handle_redirect_request(
request, None, session
)
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
else: else:

View File

@ -385,7 +385,7 @@ class CASTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", cas_ticket_url) channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML. # Test that the response is HTML.
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200, channel.result)
content_type_header_value = "" content_type_header_value = ""
for header in channel.result.get("headers", []): for header in channel.result.get("headers", []):
if header[0] == b"Content-Type": if header[0] == b"Content-Type":