cas: support setting display name (#6114)

Now, the CAS server can return an attribute stating what's the desired displayname, instead of using the username directly.
This commit is contained in:
Valérian Rousset 2019-10-11 13:33:12 +02:00 committed by Richard van der Hoff
parent a0d0ba7862
commit be9b55e0d2
4 changed files with 8 additions and 1 deletions

1
changelog.d/6114.feature Normal file
View File

@ -0,0 +1 @@
CAS login now provides a default display name for users if a `displayname_attribute` is set in the configuration file.

View File

@ -1220,6 +1220,7 @@ saml2_config:
# enabled: true # enabled: true
# server_url: "https://cas-server.com" # server_url: "https://cas-server.com"
# service_url: "https://homeserver.domain.com:8448" # service_url: "https://homeserver.domain.com:8448"
# #displayname_attribute: name
# #required_attributes: # #required_attributes:
# # name: value # # name: value

View File

@ -30,11 +30,13 @@ class CasConfig(Config):
self.cas_enabled = cas_config.get("enabled", True) self.cas_enabled = cas_config.get("enabled", True)
self.cas_server_url = cas_config["server_url"] self.cas_server_url = cas_config["server_url"]
self.cas_service_url = cas_config["service_url"] self.cas_service_url = cas_config["service_url"]
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
self.cas_required_attributes = cas_config.get("required_attributes", {}) self.cas_required_attributes = cas_config.get("required_attributes", {})
else: else:
self.cas_enabled = False self.cas_enabled = False
self.cas_server_url = None self.cas_server_url = None
self.cas_service_url = None self.cas_service_url = None
self.cas_displayname_attribute = None
self.cas_required_attributes = {} self.cas_required_attributes = {}
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
@ -45,6 +47,7 @@ class CasConfig(Config):
# enabled: true # enabled: true
# server_url: "https://cas-server.com" # server_url: "https://cas-server.com"
# service_url: "https://homeserver.domain.com:8448" # service_url: "https://homeserver.domain.com:8448"
# #displayname_attribute: name
# #required_attributes: # #required_attributes:
# # name: value # # name: value
""" """

View File

@ -377,6 +377,7 @@ class CasTicketServlet(RestServlet):
super(CasTicketServlet, self).__init__() super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url self.cas_service_url = hs.config.cas_service_url
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs) self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_simple_http_client() self._http_client = hs.get_simple_http_client()
@ -400,6 +401,7 @@ class CasTicketServlet(RestServlet):
def handle_cas_response(self, request, cas_response_body, client_redirect_url): def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body) user, attributes = self.parse_cas_response(cas_response_body)
displayname = attributes.pop(self.cas_displayname_attribute, None)
for required_attribute, required_value in self.cas_required_attributes.items(): for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden # If required attribute was not in CAS Response - Forbidden
@ -414,7 +416,7 @@ class CasTicketServlet(RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return self._sso_auth_handler.on_successful_auth( return self._sso_auth_handler.on_successful_auth(
user, request, client_redirect_url user, request, client_redirect_url, displayname
) )
def parse_cas_response(self, cas_response_body): def parse_cas_response(self, cas_response_body):