Combine the CAS & SAML implementations for required attributes. (#9326)

This commit is contained in:
Patrick Cloke 2021-02-11 10:05:15 -05:00 committed by GitHub
parent 80d6dc9783
commit 6dade80048
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 245 additions and 77 deletions

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, List, Optional
from xml.etree import ElementTree as ET
import attr
@ -49,7 +49,7 @@ class CasError(Exception):
@attr.s(slots=True, frozen=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, Optional[str]])
attributes = attr.ib(type=Dict[str, List[Optional[str]]])
class CasHandler:
@ -169,7 +169,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {}
attributes = {} # type: Dict[str, List[Optional[str]]]
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
@ -182,7 +182,7 @@ class CasHandler:
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
attributes.setdefault(tag, []).append(attribute.text)
# Ensure a user was found.
if user is None:
@ -303,29 +303,10 @@ class CasHandler:
# Ensure that the attributes of the logged in user meet the required
# attributes.
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in cas_response.attributes:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return
# Also need to check value
if required_value is not None:
actual_value = cas_response.attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return
if not self._sso_handler.check_required_attributes(
request, cas_response.attributes, self._cas_required_attributes
):
return
# Call the mapper to register/login the user
@ -372,9 +353,10 @@ class CasHandler:
if failures:
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
# Arbitrarily use the first attribute found.
display_name = cas_response.attributes.get(
self._cas_displayname_attribute, None
)
self._cas_displayname_attribute, [None]
)[0]
return UserAttributes(localpart=localpart, display_name=display_name)