Allow optional config params for a required attribute and it's value, if specified any CAS user must have the given attribute and the value must equal

This commit is contained in:
Steven Hammerton 2015-10-12 11:11:49 +01:00
parent 7845f62c22
commit 76421c496d
2 changed files with 30 additions and 1 deletions

View File

@ -27,13 +27,28 @@ class CasConfig(Config):
if cas_config: if cas_config:
self.cas_enabled = True self.cas_enabled = True
self.cas_server_url = cas_config["server_url"] self.cas_server_url = cas_config["server_url"]
if "required_attribute" in cas_config:
self.cas_required_attribute = cas_config["required_attribute"]
else:
self.cas_required_attribute = None
if "required_attribute_value" in cas_config:
self.cas_required_attribute_value = cas_config["required_attribute_value"]
else:
self.cas_required_attribute_value = None
else: else:
self.cas_enabled = False self.cas_enabled = False
self.cas_server_url = None self.cas_server_url = None
self.cas_required_attribute = None
self.cas_required_attribute_value = None
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
return """ return """
# Enable CAS for registration and login. # Enable CAS for registration and login.
#cas_config: #cas_config:
# server_url: "https://cas-server.com" # server_url: "https://cas-server.com"
# #required_attribute: something
# #required_attribute_value: true
""" """

View File

@ -45,8 +45,9 @@ class LoginRestServlet(ClientV1RestServlet):
self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.saml2_enabled = hs.config.saml2_enabled self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url
self.cas_required_attribute = hs.config.cas_required_attribute
self.cas_required_attribute_value = hs.config.cas_required_attribute_value
self.servername = hs.config.server_name self.servername = hs.config.server_name
def on_GET(self, request): def on_GET(self, request):
@ -126,6 +127,19 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_cas_login(self, cas_response_body): def do_cas_login(self, cas_response_body):
(user, attributes) = self.parse_cas_response(cas_response_body) (user, attributes) = self.parse_cas_response(cas_response_body)
if self.cas_required_attribute is not None:
# If required attribute was not in CAS Response - Forbidden
if self.cas_required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if self.cas_required_attribute_value is not None:
actualValue = attributes[self.cas_required_attribute]
# If required attribute value does not match expected - Forbidden
if self.cas_required_attribute_value != actualValue:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)