mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Refactor the CAS handler in prep for using the abstracted SSO code. (#8958)
This makes the CAS handler look more like the SAML/OIDC handlers: * Render errors to users instead of throwing JSON errors. * Internal reorganization.
This commit is contained in:
parent
56e00ca85e
commit
4218473f9e
1
changelog.d/8958.misc
Normal file
1
changelog.d/8958.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Properly store the mapping of external ID to Matrix ID for CAS users.
|
@ -31,7 +31,7 @@ easy to run CAS implementation built on top of Django.
|
|||||||
You should now have a Django project configured to serve CAS authentication with
|
You should now have a Django project configured to serve CAS authentication with
|
||||||
a single user created.
|
a single user created.
|
||||||
|
|
||||||
## Configure Synapse (and Riot) to use CAS
|
## Configure Synapse (and Element) to use CAS
|
||||||
|
|
||||||
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
|
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
|
||||||
running Django test server:
|
running Django test server:
|
||||||
@ -51,9 +51,9 @@ and that the CAS server is on port 8000, both on localhost.
|
|||||||
|
|
||||||
## Testing the configuration
|
## Testing the configuration
|
||||||
|
|
||||||
Then in Riot:
|
Then in Element:
|
||||||
|
|
||||||
1. Visit the login page with a Riot pointing at your homeserver.
|
1. Visit the login page with a Element pointing at your homeserver.
|
||||||
2. Click the Single Sign-On button.
|
2. Click the Single Sign-On button.
|
||||||
3. Login using the credentials created with `createsuperuser`.
|
3. Login using the credentials created with `createsuperuser`.
|
||||||
4. You should be logged in.
|
4. You should be logged in.
|
||||||
|
@ -13,13 +13,15 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib.parse
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
from typing import TYPE_CHECKING, Dict, Optional
|
||||||
from xml.etree import ElementTree as ET
|
from xml.etree import ElementTree as ET
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
from twisted.web.client import PartialDownloadError
|
from twisted.web.client import PartialDownloadError
|
||||||
|
|
||||||
from synapse.api.errors import Codes, LoginError
|
from synapse.api.errors import HttpResponseException
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||||
|
|
||||||
@ -29,6 +31,26 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CasError(Exception):
|
||||||
|
"""Used to catch errors when validating the CAS ticket.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, error, error_description=None):
|
||||||
|
self.error = error
|
||||||
|
self.error_description = error_description
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
if self.error_description:
|
||||||
|
return "{}: {}".format(self.error, self.error_description)
|
||||||
|
return self.error
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True)
|
||||||
|
class CasResponse:
|
||||||
|
username = attr.ib(type=str)
|
||||||
|
attributes = attr.ib(type=Dict[str, Optional[str]])
|
||||||
|
|
||||||
|
|
||||||
class CasHandler:
|
class CasHandler:
|
||||||
"""
|
"""
|
||||||
Utility class for to handle the response from a CAS SSO service.
|
Utility class for to handle the response from a CAS SSO service.
|
||||||
@ -50,6 +72,8 @@ class CasHandler:
|
|||||||
|
|
||||||
self._http_client = hs.get_proxied_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
|
|
||||||
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
|
||||||
def _build_service_param(self, args: Dict[str, str]) -> str:
|
def _build_service_param(self, args: Dict[str, str]) -> str:
|
||||||
"""
|
"""
|
||||||
Generates a value to use as the "service" parameter when redirecting or
|
Generates a value to use as the "service" parameter when redirecting or
|
||||||
@ -69,14 +93,20 @@ class CasHandler:
|
|||||||
|
|
||||||
async def _validate_ticket(
|
async def _validate_ticket(
|
||||||
self, ticket: str, service_args: Dict[str, str]
|
self, ticket: str, service_args: Dict[str, str]
|
||||||
) -> Tuple[str, Optional[str]]:
|
) -> CasResponse:
|
||||||
"""
|
"""
|
||||||
Validate a CAS ticket with the server, parse the response, and return the user and display name.
|
Validate a CAS ticket with the server, and return the parsed the response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ticket: The CAS ticket from the client.
|
ticket: The CAS ticket from the client.
|
||||||
service_args: Additional arguments to include in the service URL.
|
service_args: Additional arguments to include in the service URL.
|
||||||
Should be the same as those passed to `get_redirect_url`.
|
Should be the same as those passed to `get_redirect_url`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CasError: If there's an error parsing the CAS response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The parsed CAS response.
|
||||||
"""
|
"""
|
||||||
uri = self._cas_server_url + "/proxyValidate"
|
uri = self._cas_server_url + "/proxyValidate"
|
||||||
args = {
|
args = {
|
||||||
@ -89,43 +119,46 @@ class CasHandler:
|
|||||||
# Twisted raises this error if the connection is closed,
|
# Twisted raises this error if the connection is closed,
|
||||||
# even if that's being used old-http style to signal end-of-data
|
# even if that's being used old-http style to signal end-of-data
|
||||||
body = pde.response
|
body = pde.response
|
||||||
|
except HttpResponseException as e:
|
||||||
|
description = (
|
||||||
|
(
|
||||||
|
'Authorization server responded with a "{status}" error '
|
||||||
|
"while exchanging the authorization code."
|
||||||
|
).format(status=e.code),
|
||||||
|
)
|
||||||
|
raise CasError("server_error", description) from e
|
||||||
|
|
||||||
user, attributes = self._parse_cas_response(body)
|
return self._parse_cas_response(body)
|
||||||
displayname = attributes.pop(self._cas_displayname_attribute, None)
|
|
||||||
|
|
||||||
for required_attribute, required_value in self._cas_required_attributes.items():
|
def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
|
||||||
# If required attribute was not in CAS Response - Forbidden
|
|
||||||
if required_attribute not in attributes:
|
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
# Also need to check value
|
|
||||||
if required_value is not None:
|
|
||||||
actual_value = attributes[required_attribute]
|
|
||||||
# If required attribute value does not match expected - Forbidden
|
|
||||||
if required_value != actual_value:
|
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
return user, displayname
|
|
||||||
|
|
||||||
def _parse_cas_response(
|
|
||||||
self, cas_response_body: bytes
|
|
||||||
) -> Tuple[str, Dict[str, Optional[str]]]:
|
|
||||||
"""
|
"""
|
||||||
Retrieve the user and other parameters from the CAS response.
|
Retrieve the user and other parameters from the CAS response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cas_response_body: The response from the CAS query.
|
cas_response_body: The response from the CAS query.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CasError: If there's an error parsing the CAS response.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of the user and a mapping of other attributes.
|
The parsed CAS response.
|
||||||
"""
|
"""
|
||||||
user = None
|
|
||||||
attributes = {}
|
# Ensure the response is valid.
|
||||||
try:
|
|
||||||
root = ET.fromstring(cas_response_body)
|
root = ET.fromstring(cas_response_body)
|
||||||
if not root.tag.endswith("serviceResponse"):
|
if not root.tag.endswith("serviceResponse"):
|
||||||
raise Exception("root of CAS response is not serviceResponse")
|
raise CasError(
|
||||||
|
"missing_service_response",
|
||||||
|
"root of CAS response is not serviceResponse",
|
||||||
|
)
|
||||||
|
|
||||||
success = root[0].tag.endswith("authenticationSuccess")
|
success = root[0].tag.endswith("authenticationSuccess")
|
||||||
|
if not success:
|
||||||
|
raise CasError("unsucessful_response", "Unsuccessful CAS response")
|
||||||
|
|
||||||
|
# Iterate through the nodes and pull out the user and any extra attributes.
|
||||||
|
user = None
|
||||||
|
attributes = {}
|
||||||
for child in root[0]:
|
for child in root[0]:
|
||||||
if child.tag.endswith("user"):
|
if child.tag.endswith("user"):
|
||||||
user = child.text
|
user = child.text
|
||||||
@ -139,16 +172,12 @@ class CasHandler:
|
|||||||
if "}" in tag:
|
if "}" in tag:
|
||||||
tag = tag.split("}")[1]
|
tag = tag.split("}")[1]
|
||||||
attributes[tag] = attribute.text
|
attributes[tag] = attribute.text
|
||||||
|
|
||||||
|
# Ensure a user was found.
|
||||||
if user is None:
|
if user is None:
|
||||||
raise Exception("CAS response does not contain user")
|
raise CasError("no_user", "CAS response does not contain user")
|
||||||
except Exception:
|
|
||||||
logger.exception("Error parsing CAS response")
|
return CasResponse(user, attributes)
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
|
||||||
if not success:
|
|
||||||
raise LoginError(
|
|
||||||
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
|
|
||||||
)
|
|
||||||
return user, attributes
|
|
||||||
|
|
||||||
def get_redirect_url(self, service_args: Dict[str, str]) -> str:
|
def get_redirect_url(self, service_args: Dict[str, str]) -> str:
|
||||||
"""
|
"""
|
||||||
@ -201,7 +230,68 @@ class CasHandler:
|
|||||||
args["redirectUrl"] = client_redirect_url
|
args["redirectUrl"] = client_redirect_url
|
||||||
if session:
|
if session:
|
||||||
args["session"] = session
|
args["session"] = session
|
||||||
username, user_display_name = await self._validate_ticket(ticket, args)
|
|
||||||
|
try:
|
||||||
|
cas_response = await self._validate_ticket(ticket, args)
|
||||||
|
except CasError as e:
|
||||||
|
logger.exception("Could not validate ticket")
|
||||||
|
self._sso_handler.render_error(request, e.error, e.error_description, 401)
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._handle_cas_response(
|
||||||
|
request, cas_response, client_redirect_url, session
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_cas_response(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
cas_response: CasResponse,
|
||||||
|
client_redirect_url: Optional[str],
|
||||||
|
session: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
"""Handle a CAS response to a ticket request.
|
||||||
|
|
||||||
|
Assumes that the response has been validated. Maps the user onto an MXID,
|
||||||
|
registering them if necessary, and returns a response to the browser.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: the incoming request from the browser. We'll respond to it with an
|
||||||
|
HTML page or a redirect
|
||||||
|
|
||||||
|
cas_response: The parsed CAS response.
|
||||||
|
|
||||||
|
client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
|
||||||
|
This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
|
||||||
|
|
||||||
|
session: The session parameter from the `/cas/ticket` HTTP request, if given.
|
||||||
|
This should be the UI Auth session id.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Ensure that the attributes of the logged in user meet the required
|
||||||
|
# attributes.
|
||||||
|
for required_attribute, required_value in self._cas_required_attributes.items():
|
||||||
|
# If required attribute was not in CAS Response - Forbidden
|
||||||
|
if required_attribute not in cas_response.attributes:
|
||||||
|
self._sso_handler.render_error(
|
||||||
|
request,
|
||||||
|
"unauthorised",
|
||||||
|
"You are not authorised to log in here.",
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Also need to check value
|
||||||
|
if required_value is not None:
|
||||||
|
actual_value = cas_response.attributes[required_attribute]
|
||||||
|
# If required attribute value does not match expected - Forbidden
|
||||||
|
if required_value != actual_value:
|
||||||
|
self._sso_handler.render_error(
|
||||||
|
request,
|
||||||
|
"unauthorised",
|
||||||
|
"You are not authorised to log in here.",
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Pull out the user-agent and IP from the request.
|
# Pull out the user-agent and IP from the request.
|
||||||
user_agent = request.get_user_agent("")
|
user_agent = request.get_user_agent("")
|
||||||
@ -209,7 +299,7 @@ class CasHandler:
|
|||||||
|
|
||||||
# Get the matrix ID from the CAS username.
|
# Get the matrix ID from the CAS username.
|
||||||
user_id = await self._map_cas_user_to_matrix_user(
|
user_id = await self._map_cas_user_to_matrix_user(
|
||||||
username, user_display_name, user_agent, ip_address
|
cas_response, user_agent, ip_address
|
||||||
)
|
)
|
||||||
|
|
||||||
if session:
|
if session:
|
||||||
@ -225,18 +315,13 @@ class CasHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _map_cas_user_to_matrix_user(
|
async def _map_cas_user_to_matrix_user(
|
||||||
self,
|
self, cas_response: CasResponse, user_agent: str, ip_address: str,
|
||||||
remote_user_id: str,
|
|
||||||
display_name: Optional[str],
|
|
||||||
user_agent: str,
|
|
||||||
ip_address: str,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Given a CAS username, retrieve the user ID for it and possibly register the user.
|
Given a CAS username, retrieve the user ID for it and possibly register the user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
remote_user_id: The username from the CAS response.
|
cas_response: The parsed CAS response.
|
||||||
display_name: The display name from the CAS response.
|
|
||||||
user_agent: The user agent of the client making the request.
|
user_agent: The user agent of the client making the request.
|
||||||
ip_address: The IP address of the client making the request.
|
ip_address: The IP address of the client making the request.
|
||||||
|
|
||||||
@ -244,15 +329,17 @@ class CasHandler:
|
|||||||
The user ID associated with this response.
|
The user ID associated with this response.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
localpart = map_username_to_mxid_localpart(remote_user_id)
|
localpart = map_username_to_mxid_localpart(cas_response.username)
|
||||||
user_id = UserID(localpart, self._hostname).to_string()
|
user_id = UserID(localpart, self._hostname).to_string()
|
||||||
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
||||||
|
|
||||||
|
displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
|
||||||
|
|
||||||
# If the user does not exist, register it.
|
# If the user does not exist, register it.
|
||||||
if not registered_user_id:
|
if not registered_user_id:
|
||||||
registered_user_id = await self._registration_handler.register_user(
|
registered_user_id = await self._registration_handler.register_user(
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
default_display_name=display_name,
|
default_display_name=displayname,
|
||||||
user_agent_ips=[(user_agent, ip_address)],
|
user_agent_ips=[(user_agent, ip_address)],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -101,7 +101,11 @@ class SsoHandler:
|
|||||||
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
|
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
|
||||||
|
|
||||||
def render_error(
|
def render_error(
|
||||||
self, request, error: str, error_description: Optional[str] = None
|
self,
|
||||||
|
request: Request,
|
||||||
|
error: str,
|
||||||
|
error_description: Optional[str] = None,
|
||||||
|
code: int = 400,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Renders the error template and responds with it.
|
"""Renders the error template and responds with it.
|
||||||
|
|
||||||
@ -113,11 +117,12 @@ class SsoHandler:
|
|||||||
We'll respond with an HTML page describing the error.
|
We'll respond with an HTML page describing the error.
|
||||||
error: A technical identifier for this error.
|
error: A technical identifier for this error.
|
||||||
error_description: A human-readable description of the error.
|
error_description: A human-readable description of the error.
|
||||||
|
code: The integer error code (an HTTP response code)
|
||||||
"""
|
"""
|
||||||
html = self._error_template.render(
|
html = self._error_template.render(
|
||||||
error=error, error_description=error_description
|
error=error, error_description=error_description
|
||||||
)
|
)
|
||||||
respond_with_html(request, 400, html)
|
respond_with_html(request, code, html)
|
||||||
|
|
||||||
async def get_sso_user_by_remote_user_id(
|
async def get_sso_user_by_remote_user_id(
|
||||||
self, auth_provider_id: str, remote_user_id: str
|
self, auth_provider_id: str, remote_user_id: str
|
||||||
|
Loading…
Reference in New Issue
Block a user