Complete the SAML2 implementation (#5422)

* SAML2 Improvements and redirect stuff

Signed-off-by: Alexander Trost <galexrt@googlemail.com>

* Code cleanups and simplifications.

Also: share the saml client between redirect and response handlers.

* changelog

* Revert redundant changes to static js

* Move all the saml stuff out to a centralised handler

* Add support for tracking SAML2 sessions.

This allows us to correctly handle `allow_unsolicited: False`.

* update sample config

* cleanups

* update sample config

* rename BaseSSORedirectServlet for consistency

* Address review comments
This commit is contained in:
Richard van der Hoff 2019-07-02 11:18:11 +01:00 committed by GitHub
commit 6eecb6e500
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 231 additions and 45 deletions

View file

@ -86,6 +86,7 @@ class LoginRestServlet(RestServlet):
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
@ -97,6 +98,9 @@ class LoginRestServlet(RestServlet):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
if self.cas_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
@ -351,27 +355,49 @@ class LoginRestServlet(RestServlet):
defer.returnValue(result)
class CasRedirectServlet(RestServlet):
class BaseSSORedirectServlet(RestServlet):
"""Common base class for /login/sso/redirect impls"""
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def on_GET(self, request):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
sso_url = self.get_sso_url(client_redirect_url)
request.redirect(sso_url)
finish_request(request)
def get_sso_url(self, client_redirect_url):
"""Get the URL to redirect to, to perform SSO auth
Args:
client_redirect_url (bytes): the URL that we should redirect the
client to when everything is done
Returns:
bytes: URL to redirect to
"""
# to be implemented by subclasses
raise NotImplementedError()
class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url.encode("ascii")
self.cas_service_url = hs.config.cas_service_url.encode("ascii")
def on_GET(self, request):
args = request.args
if b"redirectUrl" not in args:
return (400, "Redirect URL not specified for CAS auth")
def get_sso_url(self, client_redirect_url):
client_redirect_url_param = urllib.parse.urlencode(
{b"redirectUrl": args[b"redirectUrl"][0]}
{b"redirectUrl": client_redirect_url}
).encode("ascii")
hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
service_param = urllib.parse.urlencode(
{b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
).encode("ascii")
request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
finish_request(request)
return b"%s/login?%s" % (self.cas_server_url, service_param)
class CasTicketServlet(RestServlet):
@ -454,6 +480,16 @@ class CasTicketServlet(RestServlet):
return user, attributes
class SAMLRedirectServlet(BaseSSORedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
def get_sso_url(self, client_redirect_url):
return self._saml_handler.handle_redirect_request(client_redirect_url)
class SSOAuthHandler(object):
"""
Utility class for Resources and Servlets which handle the response from a SSO
@ -529,3 +565,5 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
elif hs.config.saml2_enabled:
SAMLRedirectServlet(hs).register(http_server)