mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-07 19:45:03 -04:00
Combine the CAS & SAML implementations for required attributes. (#9326)
This commit is contained in:
parent
80d6dc9783
commit
6dade80048
9 changed files with 245 additions and 77 deletions
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
import attr
|
||||
|
@ -49,7 +49,7 @@ class CasError(Exception):
|
|||
@attr.s(slots=True, frozen=True)
|
||||
class CasResponse:
|
||||
username = attr.ib(type=str)
|
||||
attributes = attr.ib(type=Dict[str, Optional[str]])
|
||||
attributes = attr.ib(type=Dict[str, List[Optional[str]]])
|
||||
|
||||
|
||||
class CasHandler:
|
||||
|
@ -169,7 +169,7 @@ class CasHandler:
|
|||
|
||||
# Iterate through the nodes and pull out the user and any extra attributes.
|
||||
user = None
|
||||
attributes = {}
|
||||
attributes = {} # type: Dict[str, List[Optional[str]]]
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
|
@ -182,7 +182,7 @@ class CasHandler:
|
|||
tag = attribute.tag
|
||||
if "}" in tag:
|
||||
tag = tag.split("}")[1]
|
||||
attributes[tag] = attribute.text
|
||||
attributes.setdefault(tag, []).append(attribute.text)
|
||||
|
||||
# Ensure a user was found.
|
||||
if user is None:
|
||||
|
@ -303,29 +303,10 @@ class CasHandler:
|
|||
|
||||
# Ensure that the attributes of the logged in user meet the required
|
||||
# attributes.
|
||||
for required_attribute, required_value in self._cas_required_attributes.items():
|
||||
# If required attribute was not in CAS Response - Forbidden
|
||||
if required_attribute not in cas_response.attributes:
|
||||
self._sso_handler.render_error(
|
||||
request,
|
||||
"unauthorised",
|
||||
"You are not authorised to log in here.",
|
||||
401,
|
||||
)
|
||||
return
|
||||
|
||||
# Also need to check value
|
||||
if required_value is not None:
|
||||
actual_value = cas_response.attributes[required_attribute]
|
||||
# If required attribute value does not match expected - Forbidden
|
||||
if required_value != actual_value:
|
||||
self._sso_handler.render_error(
|
||||
request,
|
||||
"unauthorised",
|
||||
"You are not authorised to log in here.",
|
||||
401,
|
||||
)
|
||||
return
|
||||
if not self._sso_handler.check_required_attributes(
|
||||
request, cas_response.attributes, self._cas_required_attributes
|
||||
):
|
||||
return
|
||||
|
||||
# Call the mapper to register/login the user
|
||||
|
||||
|
@ -372,9 +353,10 @@ class CasHandler:
|
|||
if failures:
|
||||
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
|
||||
|
||||
# Arbitrarily use the first attribute found.
|
||||
display_name = cas_response.attributes.get(
|
||||
self._cas_displayname_attribute, None
|
||||
)
|
||||
self._cas_displayname_attribute, [None]
|
||||
)[0]
|
||||
|
||||
return UserAttributes(localpart=localpart, display_name=display_name)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue