Refactor the CAS handler in prep for using the abstracted SSO code. (#8958)

This makes the CAS handler look more like the SAML/OIDC handlers:

* Render errors to users instead of throwing JSON errors.
* Internal reorganization.
This commit is contained in:
Patrick Cloke 2020-12-18 13:09:45 -05:00 committed by GitHub
parent 56e00ca85e
commit 4218473f9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 162 additions and 69 deletions

1
changelog.d/8958.misc Normal file
View File

@ -0,0 +1 @@
Properly store the mapping of external ID to Matrix ID for CAS users.

View File

@ -31,7 +31,7 @@ easy to run CAS implementation built on top of Django.
You should now have a Django project configured to serve CAS authentication with
a single user created.
## Configure Synapse (and Riot) to use CAS
## Configure Synapse (and Element) to use CAS
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
running Django test server:
@ -51,9 +51,9 @@ and that the CAS server is on port 8000, both on localhost.
## Testing the configuration
Then in Riot:
Then in Element:
1. Visit the login page with a Riot pointing at your homeserver.
1. Visit the login page with a Element pointing at your homeserver.
2. Click the Single Sign-On button.
3. Login using the credentials created with `createsuperuser`.
4. You should be logged in.

View File

@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import urllib.parse
from typing import TYPE_CHECKING, Dict, Optional
from xml.etree import ElementTree as ET
import attr
from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError
from synapse.api.errors import HttpResponseException
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
@ -29,6 +31,26 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class CasError(Exception):
"""Used to catch errors when validating the CAS ticket.
"""
def __init__(self, error, error_description=None):
self.error = error
self.error_description = error_description
def __str__(self):
if self.error_description:
return "{}: {}".format(self.error, self.error_description)
return self.error
@attr.s(slots=True, frozen=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, Optional[str]])
class CasHandler:
"""
Utility class for to handle the response from a CAS SSO service.
@ -50,6 +72,8 @@ class CasHandler:
self._http_client = hs.get_proxied_http_client()
self._sso_handler = hs.get_sso_handler()
def _build_service_param(self, args: Dict[str, str]) -> str:
"""
Generates a value to use as the "service" parameter when redirecting or
@ -69,14 +93,20 @@ class CasHandler:
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
) -> Tuple[str, Optional[str]]:
) -> CasResponse:
"""
Validate a CAS ticket with the server, parse the response, and return the user and display name.
Validate a CAS ticket with the server, and return the parsed the response.
Args:
ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL.
Should be the same as those passed to `get_redirect_url`.
Raises:
CasError: If there's an error parsing the CAS response.
Returns:
The parsed CAS response.
"""
uri = self._cas_server_url + "/proxyValidate"
args = {
@ -89,43 +119,46 @@ class CasHandler:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
except HttpResponseException as e:
description = (
(
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code),
)
raise CasError("server_error", description) from e
user, attributes = self._parse_cas_response(body)
displayname = attributes.pop(self._cas_displayname_attribute, None)
return self._parse_cas_response(body)
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return user, displayname
def _parse_cas_response(
self, cas_response_body: bytes
) -> Tuple[str, Dict[str, Optional[str]]]:
def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
"""
Retrieve the user and other parameters from the CAS response.
Args:
cas_response_body: The response from the CAS query.
Raises:
CasError: If there's an error parsing the CAS response.
Returns:
A tuple of the user and a mapping of other attributes.
The parsed CAS response.
"""
user = None
attributes = {}
try:
# Ensure the response is valid.
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
raise CasError(
"missing_service_response",
"root of CAS response is not serviceResponse",
)
success = root[0].tag.endswith("authenticationSuccess")
if not success:
raise CasError("unsucessful_response", "Unsuccessful CAS response")
# Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {}
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
@ -139,16 +172,12 @@ class CasHandler:
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
# Ensure a user was found.
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
raise CasError("no_user", "CAS response does not contain user")
return CasResponse(user, attributes)
def get_redirect_url(self, service_args: Dict[str, str]) -> str:
"""
@ -201,7 +230,68 @@ class CasHandler:
args["redirectUrl"] = client_redirect_url
if session:
args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args)
try:
cas_response = await self._validate_ticket(ticket, args)
except CasError as e:
logger.exception("Could not validate ticket")
self._sso_handler.render_error(request, e.error, e.error_description, 401)
return
await self._handle_cas_response(
request, cas_response, client_redirect_url, session
)
async def _handle_cas_response(
self,
request: SynapseRequest,
cas_response: CasResponse,
client_redirect_url: Optional[str],
session: Optional[str],
) -> None:
"""Handle a CAS response to a ticket request.
Assumes that the response has been validated. Maps the user onto an MXID,
registering them if necessary, and returns a response to the browser.
Args:
request: the incoming request from the browser. We'll respond to it with an
HTML page or a redirect
cas_response: The parsed CAS response.
client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
session: The session parameter from the `/cas/ticket` HTTP request, if given.
This should be the UI Auth session id.
"""
# Ensure that the attributes of the logged in user meet the required
# attributes.
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in cas_response.attributes:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return
# Also need to check value
if required_value is not None:
actual_value = cas_response.attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return
# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
@ -209,7 +299,7 @@ class CasHandler:
# Get the matrix ID from the CAS username.
user_id = await self._map_cas_user_to_matrix_user(
username, user_display_name, user_agent, ip_address
cas_response, user_agent, ip_address
)
if session:
@ -225,18 +315,13 @@ class CasHandler:
)
async def _map_cas_user_to_matrix_user(
self,
remote_user_id: str,
display_name: Optional[str],
user_agent: str,
ip_address: str,
self, cas_response: CasResponse, user_agent: str, ip_address: str,
) -> str:
"""
Given a CAS username, retrieve the user ID for it and possibly register the user.
Args:
remote_user_id: The username from the CAS response.
display_name: The display name from the CAS response.
cas_response: The parsed CAS response.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
@ -244,15 +329,17 @@ class CasHandler:
The user ID associated with this response.
"""
localpart = map_username_to_mxid_localpart(remote_user_id)
localpart = map_username_to_mxid_localpart(cas_response.username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
# If the user does not exist, register it.
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=display_name,
default_display_name=displayname,
user_agent_ips=[(user_agent, ip_address)],
)

View File

@ -101,7 +101,11 @@ class SsoHandler:
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
def render_error(
self, request, error: str, error_description: Optional[str] = None
self,
request: Request,
error: str,
error_description: Optional[str] = None,
code: int = 400,
) -> None:
"""Renders the error template and responds with it.
@ -113,11 +117,12 @@ class SsoHandler:
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
code: The integer error code (an HTTP response code)
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
respond_with_html(request, code, html)
async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str