Refactor the CAS handler in prep for using the abstracted SSO code. (#8958)

This makes the CAS handler look more like the SAML/OIDC handlers:

* Render errors to users instead of throwing JSON errors.
* Internal reorganization.
This commit is contained in:
Patrick Cloke 2020-12-18 13:09:45 -05:00 committed by GitHub
parent 56e00ca85e
commit 4218473f9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 162 additions and 69 deletions

1
changelog.d/8958.misc Normal file
View File

@ -0,0 +1 @@
Properly store the mapping of external ID to Matrix ID for CAS users.

View File

@ -31,7 +31,7 @@ easy to run CAS implementation built on top of Django.
You should now have a Django project configured to serve CAS authentication with You should now have a Django project configured to serve CAS authentication with
a single user created. a single user created.
## Configure Synapse (and Riot) to use CAS ## Configure Synapse (and Element) to use CAS
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally 1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
running Django test server: running Django test server:
@ -51,9 +51,9 @@ and that the CAS server is on port 8000, both on localhost.
## Testing the configuration ## Testing the configuration
Then in Riot: Then in Element:
1. Visit the login page with a Riot pointing at your homeserver. 1. Visit the login page with a Element pointing at your homeserver.
2. Click the Single Sign-On button. 2. Click the Single Sign-On button.
3. Login using the credentials created with `createsuperuser`. 3. Login using the credentials created with `createsuperuser`.
4. You should be logged in. 4. You should be logged in.

View File

@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import urllib import urllib.parse
from typing import TYPE_CHECKING, Dict, Optional, Tuple from typing import TYPE_CHECKING, Dict, Optional
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
import attr
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError from synapse.api.errors import HttpResponseException
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart from synapse.types import UserID, map_username_to_mxid_localpart
@ -29,6 +31,26 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CasError(Exception):
"""Used to catch errors when validating the CAS ticket.
"""
def __init__(self, error, error_description=None):
self.error = error
self.error_description = error_description
def __str__(self):
if self.error_description:
return "{}: {}".format(self.error, self.error_description)
return self.error
@attr.s(slots=True, frozen=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, Optional[str]])
class CasHandler: class CasHandler:
""" """
Utility class for to handle the response from a CAS SSO service. Utility class for to handle the response from a CAS SSO service.
@ -50,6 +72,8 @@ class CasHandler:
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
self._sso_handler = hs.get_sso_handler()
def _build_service_param(self, args: Dict[str, str]) -> str: def _build_service_param(self, args: Dict[str, str]) -> str:
""" """
Generates a value to use as the "service" parameter when redirecting or Generates a value to use as the "service" parameter when redirecting or
@ -69,14 +93,20 @@ class CasHandler:
async def _validate_ticket( async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str] self, ticket: str, service_args: Dict[str, str]
) -> Tuple[str, Optional[str]]: ) -> CasResponse:
""" """
Validate a CAS ticket with the server, parse the response, and return the user and display name. Validate a CAS ticket with the server, and return the parsed the response.
Args: Args:
ticket: The CAS ticket from the client. ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL. service_args: Additional arguments to include in the service URL.
Should be the same as those passed to `get_redirect_url`. Should be the same as those passed to `get_redirect_url`.
Raises:
CasError: If there's an error parsing the CAS response.
Returns:
The parsed CAS response.
""" """
uri = self._cas_server_url + "/proxyValidate" uri = self._cas_server_url + "/proxyValidate"
args = { args = {
@ -89,66 +119,65 @@ class CasHandler:
# Twisted raises this error if the connection is closed, # Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data # even if that's being used old-http style to signal end-of-data
body = pde.response body = pde.response
except HttpResponseException as e:
description = (
(
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code),
)
raise CasError("server_error", description) from e
user, attributes = self._parse_cas_response(body) return self._parse_cas_response(body)
displayname = attributes.pop(self._cas_displayname_attribute, None)
for required_attribute, required_value in self._cas_required_attributes.items(): def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return user, displayname
def _parse_cas_response(
self, cas_response_body: bytes
) -> Tuple[str, Dict[str, Optional[str]]]:
""" """
Retrieve the user and other parameters from the CAS response. Retrieve the user and other parameters from the CAS response.
Args: Args:
cas_response_body: The response from the CAS query. cas_response_body: The response from the CAS query.
Raises:
CasError: If there's an error parsing the CAS response.
Returns: Returns:
A tuple of the user and a mapping of other attributes. The parsed CAS response.
""" """
# Ensure the response is valid.
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise CasError(
"missing_service_response",
"root of CAS response is not serviceResponse",
)
success = root[0].tag.endswith("authenticationSuccess")
if not success:
raise CasError("unsucessful_response", "Unsuccessful CAS response")
# Iterate through the nodes and pull out the user and any extra attributes.
user = None user = None
attributes = {} attributes = {}
try: for child in root[0]:
root = ET.fromstring(cas_response_body) if child.tag.endswith("user"):
if not root.tag.endswith("serviceResponse"): user = child.text
raise Exception("root of CAS response is not serviceResponse") if child.tag.endswith("attributes"):
success = root[0].tag.endswith("authenticationSuccess") for attribute in child:
for child in root[0]: # ElementTree library expands the namespace in
if child.tag.endswith("user"): # attribute tags to the full URL of the namespace.
user = child.text # We don't care about namespace here and it will always
if child.tag.endswith("attributes"): # be encased in curly braces, so we remove them.
for attribute in child: tag = attribute.tag
# ElementTree library expands the namespace in if "}" in tag:
# attribute tags to the full URL of the namespace. tag = tag.split("}")[1]
# We don't care about namespace here and it will always attributes[tag] = attribute.text
# be encased in curly braces, so we remove them.
tag = attribute.tag # Ensure a user was found.
if "}" in tag: if user is None:
tag = tag.split("}")[1] raise CasError("no_user", "CAS response does not contain user")
attributes[tag] = attribute.text
if user is None: return CasResponse(user, attributes)
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
def get_redirect_url(self, service_args: Dict[str, str]) -> str: def get_redirect_url(self, service_args: Dict[str, str]) -> str:
""" """
@ -201,7 +230,68 @@ class CasHandler:
args["redirectUrl"] = client_redirect_url args["redirectUrl"] = client_redirect_url
if session: if session:
args["session"] = session args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args)
try:
cas_response = await self._validate_ticket(ticket, args)
except CasError as e:
logger.exception("Could not validate ticket")
self._sso_handler.render_error(request, e.error, e.error_description, 401)
return
await self._handle_cas_response(
request, cas_response, client_redirect_url, session
)
async def _handle_cas_response(
self,
request: SynapseRequest,
cas_response: CasResponse,
client_redirect_url: Optional[str],
session: Optional[str],
) -> None:
"""Handle a CAS response to a ticket request.
Assumes that the response has been validated. Maps the user onto an MXID,
registering them if necessary, and returns a response to the browser.
Args:
request: the incoming request from the browser. We'll respond to it with an
HTML page or a redirect
cas_response: The parsed CAS response.
client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
session: The session parameter from the `/cas/ticket` HTTP request, if given.
This should be the UI Auth session id.
"""
# Ensure that the attributes of the logged in user meet the required
# attributes.
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in cas_response.attributes:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return
# Also need to check value
if required_value is not None:
actual_value = cas_response.attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return
# Pull out the user-agent and IP from the request. # Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("") user_agent = request.get_user_agent("")
@ -209,7 +299,7 @@ class CasHandler:
# Get the matrix ID from the CAS username. # Get the matrix ID from the CAS username.
user_id = await self._map_cas_user_to_matrix_user( user_id = await self._map_cas_user_to_matrix_user(
username, user_display_name, user_agent, ip_address cas_response, user_agent, ip_address
) )
if session: if session:
@ -225,18 +315,13 @@ class CasHandler:
) )
async def _map_cas_user_to_matrix_user( async def _map_cas_user_to_matrix_user(
self, self, cas_response: CasResponse, user_agent: str, ip_address: str,
remote_user_id: str,
display_name: Optional[str],
user_agent: str,
ip_address: str,
) -> str: ) -> str:
""" """
Given a CAS username, retrieve the user ID for it and possibly register the user. Given a CAS username, retrieve the user ID for it and possibly register the user.
Args: Args:
remote_user_id: The username from the CAS response. cas_response: The parsed CAS response.
display_name: The display name from the CAS response.
user_agent: The user agent of the client making the request. user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request. ip_address: The IP address of the client making the request.
@ -244,15 +329,17 @@ class CasHandler:
The user ID associated with this response. The user ID associated with this response.
""" """
localpart = map_username_to_mxid_localpart(remote_user_id) localpart = map_username_to_mxid_localpart(cas_response.username)
user_id = UserID(localpart, self._hostname).to_string() user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id) registered_user_id = await self._auth_handler.check_user_exists(user_id)
displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
# If the user does not exist, register it. # If the user does not exist, register it.
if not registered_user_id: if not registered_user_id:
registered_user_id = await self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, localpart=localpart,
default_display_name=display_name, default_display_name=displayname,
user_agent_ips=[(user_agent, ip_address)], user_agent_ips=[(user_agent, ip_address)],
) )

View File

@ -101,7 +101,11 @@ class SsoHandler:
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
def render_error( def render_error(
self, request, error: str, error_description: Optional[str] = None self,
request: Request,
error: str,
error_description: Optional[str] = None,
code: int = 400,
) -> None: ) -> None:
"""Renders the error template and responds with it. """Renders the error template and responds with it.
@ -113,11 +117,12 @@ class SsoHandler:
We'll respond with an HTML page describing the error. We'll respond with an HTML page describing the error.
error: A technical identifier for this error. error: A technical identifier for this error.
error_description: A human-readable description of the error. error_description: A human-readable description of the error.
code: The integer error code (an HTTP response code)
""" """
html = self._error_template.render( html = self._error_template.render(
error=error, error_description=error_description error=error, error_description=error_description
) )
respond_with_html(request, 400, html) respond_with_html(request, code, html)
async def get_sso_user_by_remote_user_id( async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str self, auth_provider_id: str, remote_user_id: str