mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-02-27 12:11:07 -05:00
Support multiple required attributes in CAS response, and in a nicer config format too
This commit is contained in:
parent
76421c496d
commit
01a5f1991c
@ -27,28 +27,17 @@ class CasConfig(Config):
|
|||||||
if cas_config:
|
if cas_config:
|
||||||
self.cas_enabled = True
|
self.cas_enabled = True
|
||||||
self.cas_server_url = cas_config["server_url"]
|
self.cas_server_url = cas_config["server_url"]
|
||||||
|
self.cas_required_attributes = cas_config.get("required_attributes", None)
|
||||||
if "required_attribute" in cas_config:
|
|
||||||
self.cas_required_attribute = cas_config["required_attribute"]
|
|
||||||
else:
|
|
||||||
self.cas_required_attribute = None
|
|
||||||
|
|
||||||
if "required_attribute_value" in cas_config:
|
|
||||||
self.cas_required_attribute_value = cas_config["required_attribute_value"]
|
|
||||||
else:
|
|
||||||
self.cas_required_attribute_value = None
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.cas_enabled = False
|
self.cas_enabled = False
|
||||||
self.cas_server_url = None
|
self.cas_server_url = None
|
||||||
self.cas_required_attribute = None
|
self.cas_required_attributes = {}
|
||||||
self.cas_required_attribute_value = None
|
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||||
return """
|
return """
|
||||||
# Enable CAS for registration and login.
|
# Enable CAS for registration and login.
|
||||||
#cas_config:
|
#cas_config:
|
||||||
# server_url: "https://cas-server.com"
|
# server_url: "https://cas-server.com"
|
||||||
# #required_attribute: something
|
# #required_attributes:
|
||||||
# #required_attribute_value: true
|
# # name: value
|
||||||
"""
|
"""
|
||||||
|
@ -46,8 +46,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
self.saml2_enabled = hs.config.saml2_enabled
|
self.saml2_enabled = hs.config.saml2_enabled
|
||||||
self.cas_enabled = hs.config.cas_enabled
|
self.cas_enabled = hs.config.cas_enabled
|
||||||
self.cas_server_url = hs.config.cas_server_url
|
self.cas_server_url = hs.config.cas_server_url
|
||||||
self.cas_required_attribute = hs.config.cas_required_attribute
|
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||||
self.cas_required_attribute_value = hs.config.cas_required_attribute_value
|
|
||||||
self.servername = hs.config.server_name
|
self.servername = hs.config.server_name
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
@ -128,16 +127,16 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
def do_cas_login(self, cas_response_body):
|
def do_cas_login(self, cas_response_body):
|
||||||
(user, attributes) = self.parse_cas_response(cas_response_body)
|
(user, attributes) = self.parse_cas_response(cas_response_body)
|
||||||
|
|
||||||
if self.cas_required_attribute is not None:
|
for required_attribute in self.cas_required_attributes:
|
||||||
# If required attribute was not in CAS Response - Forbidden
|
# If required attribute was not in CAS Response - Forbidden
|
||||||
if self.cas_required_attribute not in attributes:
|
if required_attribute not in attributes:
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
# Also need to check value
|
# Also need to check value
|
||||||
if self.cas_required_attribute_value is not None:
|
if self.cas_required_attributes[required_attribute] is not None:
|
||||||
actualValue = attributes[self.cas_required_attribute]
|
actualValue = attributes[required_attribute]
|
||||||
# If required attribute value does not match expected - Forbidden
|
# If required attribute value does not match expected - Forbidden
|
||||||
if self.cas_required_attribute_value != actualValue:
|
if self.cas_required_attributes[required_attribute] != actualValue:
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user