diff --git a/changelog.d/9326.misc b/changelog.d/9326.misc new file mode 100644 index 000000000..768c18d27 --- /dev/null +++ b/changelog.d/9326.misc @@ -0,0 +1 @@ +Share the code for handling required attributes between the CAS and SAML handlers. diff --git a/synapse/config/cas.py b/synapse/config/cas.py index b226890c2..daea848d2 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -13,7 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, List + +from synapse.config.sso import SsoAttributeRequirement + from ._base import Config +from ._util import validate_config class CasConfig(Config): @@ -38,12 +43,16 @@ class CasConfig(Config): public_base_url + "_matrix/client/r0/login/cas/ticket" ) self.cas_displayname_attribute = cas_config.get("displayname_attribute") - self.cas_required_attributes = cas_config.get("required_attributes") or {} + required_attributes = cas_config.get("required_attributes") or {} + self.cas_required_attributes = _parsed_required_attributes_def( + required_attributes + ) + else: self.cas_server_url = None self.cas_service_url = None self.cas_displayname_attribute = None - self.cas_required_attributes = {} + self.cas_required_attributes = [] def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ @@ -75,3 +84,22 @@ class CasConfig(Config): # userGroup: "staff" # department: None """ + + +# CAS uses a legacy required attributes mapping, not the one provided by +# SsoAttributeRequirement. +REQUIRED_ATTRIBUTES_SCHEMA = { + "type": "object", + "additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]}, +} + + +def _parsed_required_attributes_def( + required_attributes: Any, +) -> List[SsoAttributeRequirement]: + validate_config( + REQUIRED_ATTRIBUTES_SCHEMA, + required_attributes, + config_path=("cas_config", "required_attributes"), + ) + return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()] diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index ad865a667..1820614bc 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -17,8 +17,7 @@ import logging from typing import Any, List -import attr - +from synapse.config.sso import SsoAttributeRequirement from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module @@ -396,32 +395,18 @@ class SAML2Config(Config): } -@attr.s(frozen=True) -class SamlAttributeRequirement: - """Object describing a single requirement for SAML attributes.""" - - attribute = attr.ib(type=str) - value = attr.ib(type=str) - - JSON_SCHEMA = { - "type": "object", - "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}}, - "required": ["attribute", "value"], - } - - ATTRIBUTE_REQUIREMENTS_SCHEMA = { "type": "array", - "items": SamlAttributeRequirement.JSON_SCHEMA, + "items": SsoAttributeRequirement.JSON_SCHEMA, } def _parse_attribute_requirements_def( attribute_requirements: Any, -) -> List[SamlAttributeRequirement]: +) -> List[SsoAttributeRequirement]: validate_config( ATTRIBUTE_REQUIREMENTS_SCHEMA, attribute_requirements, - config_path=["saml2_config", "attribute_requirements"], + config_path=("saml2_config", "attribute_requirements"), ) - return [SamlAttributeRequirement(**x) for x in attribute_requirements] + return [SsoAttributeRequirement(**x) for x in attribute_requirements] diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 6c60c6fea..b94d3cd5e 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -12,11 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Dict, Optional + +import attr from ._base import Config +@attr.s(frozen=True) +class SsoAttributeRequirement: + """Object describing a single requirement for SSO attributes.""" + + attribute = attr.ib(type=str) + # If a value is not given, than the attribute must simply exist. + value = attr.ib(type=Optional[str]) + + JSON_SCHEMA = { + "type": "object", + "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}}, + "required": ["attribute", "value"], + } + + class SSOConfig(Config): """SSO Configuration """ diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index bd35d1fb8..81ed44ac8 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from xml.etree import ElementTree as ET import attr @@ -49,7 +49,7 @@ class CasError(Exception): @attr.s(slots=True, frozen=True) class CasResponse: username = attr.ib(type=str) - attributes = attr.ib(type=Dict[str, Optional[str]]) + attributes = attr.ib(type=Dict[str, List[Optional[str]]]) class CasHandler: @@ -169,7 +169,7 @@ class CasHandler: # Iterate through the nodes and pull out the user and any extra attributes. user = None - attributes = {} + attributes = {} # type: Dict[str, List[Optional[str]]] for child in root[0]: if child.tag.endswith("user"): user = child.text @@ -182,7 +182,7 @@ class CasHandler: tag = attribute.tag if "}" in tag: tag = tag.split("}")[1] - attributes[tag] = attribute.text + attributes.setdefault(tag, []).append(attribute.text) # Ensure a user was found. if user is None: @@ -303,29 +303,10 @@ class CasHandler: # Ensure that the attributes of the logged in user meet the required # attributes. - for required_attribute, required_value in self._cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in cas_response.attributes: - self._sso_handler.render_error( - request, - "unauthorised", - "You are not authorised to log in here.", - 401, - ) - return - - # Also need to check value - if required_value is not None: - actual_value = cas_response.attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - self._sso_handler.render_error( - request, - "unauthorised", - "You are not authorised to log in here.", - 401, - ) - return + if not self._sso_handler.check_required_attributes( + request, cas_response.attributes, self._cas_required_attributes + ): + return # Call the mapper to register/login the user @@ -372,9 +353,10 @@ class CasHandler: if failures: raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs") + # Arbitrarily use the first attribute found. display_name = cas_response.attributes.get( - self._cas_displayname_attribute, None - ) + self._cas_displayname_attribute, [None] + )[0] return UserAttributes(localpart=localpart, display_name=display_name) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index e88fd5974..78f130e15 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -23,7 +23,6 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError -from synapse.config.saml2_config import SamlAttributeRequirement from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string @@ -239,12 +238,10 @@ class SamlHandler(BaseHandler): # Ensure that the attributes of the logged in user meet the required # attributes. - for requirement in self._saml2_attribute_requirements: - if not _check_attribute_requirement(saml2_auth.ava, requirement): - self._sso_handler.render_error( - request, "unauthorised", "You are not authorised to log in here." - ) - return + if not self._sso_handler.check_required_attributes( + request, saml2_auth.ava, self._saml2_attribute_requirements + ): + return # Call the mapper to register/login the user try: @@ -373,21 +370,6 @@ class SamlHandler(BaseHandler): del self._outstanding_requests_dict[reqid] -def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool: - values = ava.get(req.attribute, []) - for v in values: - if v == req.value: - return True - - logger.info( - "SAML2 attribute %s did not match required value '%s' (was '%s')", - req.attribute, - req.value, - values, - ) - return False - - DOT_REPLACE_PATTERN = re.compile( ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 96ccd991e..a63fd5248 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -16,10 +16,12 @@ import abc import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, Dict, Iterable, + List, Mapping, Optional, Set, @@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError +from synapse.config.sso import SsoAttributeRequirement from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect @@ -893,6 +896,41 @@ class SsoHandler: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + def check_required_attributes( + self, + request: SynapseRequest, + attributes: Mapping[str, List[Any]], + attribute_requirements: Iterable[SsoAttributeRequirement], + ) -> bool: + """ + Confirm that the required attributes were present in the SSO response. + + If all requirements are met, this will return True. + + If any requirement is not met, then the request will be finalized by + showing an error page to the user and False will be returned. + + Args: + request: The request to (potentially) respond to. + attributes: The attributes from the SSO IdP. + attribute_requirements: The requirements that attributes must meet. + + Returns: + True if all requirements are met, False if any attribute fails to + meet the requirement. + + """ + # Ensure that the attributes of the logged in user meet the required + # attributes. + for requirement in attribute_requirements: + if not _check_attribute_requirement(attributes, requirement): + self.render_error( + request, "unauthorised", "You are not authorised to log in here." + ) + return False + + return True + def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: """Extract the session ID from the cookie @@ -903,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: if not session_id: raise SynapseError(code=400, msg="missing session_id") return session_id.decode("ascii", errors="replace") + + +def _check_attribute_requirement( + attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement +) -> bool: + """Check if SSO attributes meet the proper requirements. + + Args: + attributes: A mapping of attributes to an iterable of one or more values. + requirement: The configured requirement to check. + + Returns: + True if the required attribute was found and had a proper value. + """ + if req.attribute not in attributes: + logger.info("SSO attribute missing: %s", req.attribute) + return False + + # If the requirement is None, the attribute existing is enough. + if req.value is None: + return True + + values = attributes[req.attribute] + if req.value in values: + return True + + logger.info( + "SSO attribute %s did not match required value '%s' (was '%s')", + req.attribute, + req.value, + values, + ) + return False diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 7baf224f7..6f992291b 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -16,7 +16,7 @@ from mock import Mock from synapse.handlers.cas_handler import CasResponse from tests.test_utils import simple_async_mock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. BASE_URL = "https://synapse/" @@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase): "server_url": SERVER_URL, "service_url": BASE_URL, } + + # Update this config with what's in the default config so that + # override_config works as expected. + cas_config.update(config.get("cas_config", {})) config["cas_config"] = cas_config return config @@ -115,7 +119,51 @@ class CasHandlerTestCase(HomeserverTestCase): "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True ) + @override_config( + { + "cas_config": { + "required_attributes": {"userGroup": "staff", "department": None} + } + } + ) + def test_required_attributes(self): + """The required attributes must be met from the CAS response.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # The response doesn't have the proper userGroup or department. + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + auth_handler.complete_sso_login.assert_not_called() + + # The response doesn't have any department. + cas_response = CasResponse("test_user", {"userGroup": "staff"}) + request.reset_mock() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + auth_handler.complete_sso_login.assert_not_called() + + # Add the proper attributes and it should succeed. + cas_response = CasResponse( + "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]} + ) + request.reset_mock() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None, new_user=True + ) + def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader"]) + return Mock(spec=["getClientIP", "getHeader", "_disconnected"]) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index a8d6c0f61..029af2853 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -259,7 +259,61 @@ class SamlHandlerTestCase(HomeserverTestCase): ) self.assertEqual(e.value.location, b"https://custom-saml-redirect/") + @override_config( + { + "saml2_config": { + "attribute_requirements": [ + {"attribute": "userGroup", "value": "staff"}, + {"attribute": "department", "value": "sales"}, + ], + }, + } + ) + def test_attribute_requirements(self): + """The required attributes must be met from the SAML response.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # The response doesn't have the proper userGroup or department. + saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + auth_handler.complete_sso_login.assert_not_called() + + # The response doesn't have the proper department. + saml_response = FakeAuthnResponse( + {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]} + ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + auth_handler.complete_sso_login.assert_not_called() + + # Add the proper attributes and it should succeed. + saml_response = FakeAuthnResponse( + { + "uid": "test_user", + "username": "test_user", + "userGroup": ["staff", "admin"], + "department": ["sales"], + } + ) + request.reset_mock() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None, new_user=True + ) + def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader"]) + return Mock(spec=["getClientIP", "getHeader", "_disconnected"])