mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-06-04 19:38:57 -04:00
Refactor the CAS code (move the logic out of the REST layer to a handler) (#7136)
This commit is contained in:
parent
825fb5d0a5
commit
fa4f12102d
5 changed files with 226 additions and 154 deletions
|
@ -14,11 +14,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from six.moves import urllib
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
|
@ -28,9 +23,10 @@ from synapse.http.servlet import (
|
|||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||
from synapse.rest.well_known import WellKnownBuilder
|
||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||
from synapse.types import UserID
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -72,14 +68,6 @@ def login_id_thirdparty_from_phone(identifier):
|
|||
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
|
||||
|
||||
|
||||
def build_service_param(cas_service_url, client_redirect_url):
|
||||
return "%s%s?redirectUrl=%s" % (
|
||||
cas_service_url,
|
||||
"/_matrix/client/r0/login/cas/ticket",
|
||||
urllib.parse.quote(client_redirect_url, safe=""),
|
||||
)
|
||||
|
||||
|
||||
class LoginRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/login$", v1=True)
|
||||
CAS_TYPE = "m.login.cas"
|
||||
|
@ -409,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
|
||||
|
||||
def on_GET(self, request):
|
||||
def on_GET(self, request: SynapseRequest):
|
||||
args = request.args
|
||||
if b"redirectUrl" not in args:
|
||||
return 400, "Redirect URL not specified for SSO auth"
|
||||
|
@ -418,15 +406,15 @@ class BaseSSORedirectServlet(RestServlet):
|
|||
request.redirect(sso_url)
|
||||
finish_request(request)
|
||||
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
|
||||
"""Get the URL to redirect to, to perform SSO auth
|
||||
|
||||
Args:
|
||||
client_redirect_url (bytes): the URL that we should redirect the
|
||||
client_redirect_url: the URL that we should redirect the
|
||||
client to when everything is done
|
||||
|
||||
Returns:
|
||||
bytes: URL to redirect to
|
||||
URL to redirect to
|
||||
"""
|
||||
# to be implemented by subclasses
|
||||
raise NotImplementedError()
|
||||
|
@ -434,16 +422,10 @@ class BaseSSORedirectServlet(RestServlet):
|
|||
|
||||
class CasRedirectServlet(BaseSSORedirectServlet):
|
||||
def __init__(self, hs):
|
||||
super(CasRedirectServlet, self).__init__()
|
||||
self.cas_server_url = hs.config.cas_server_url
|
||||
self.cas_service_url = hs.config.cas_service_url
|
||||
self._cas_handler = hs.get_cas_handler()
|
||||
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
args = urllib.parse.urlencode(
|
||||
{"service": build_service_param(self.cas_service_url, client_redirect_url)}
|
||||
)
|
||||
|
||||
return "%s/login?%s" % (self.cas_server_url, args)
|
||||
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
|
||||
return self._cas_handler.handle_redirect_request(client_redirect_url)
|
||||
|
||||
|
||||
class CasTicketServlet(RestServlet):
|
||||
|
@ -451,81 +433,15 @@ class CasTicketServlet(RestServlet):
|
|||
|
||||
def __init__(self, hs):
|
||||
super(CasTicketServlet, self).__init__()
|
||||
self.cas_server_url = hs.config.cas_server_url
|
||||
self.cas_service_url = hs.config.cas_service_url
|
||||
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
|
||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
self._cas_handler = hs.get_cas_handler()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> None:
|
||||
client_redirect_url = parse_string(request, "redirectUrl", required=True)
|
||||
uri = self.cas_server_url + "/proxyValidate"
|
||||
args = {
|
||||
"ticket": parse_string(request, "ticket", required=True),
|
||||
"service": build_service_param(self.cas_service_url, client_redirect_url),
|
||||
}
|
||||
try:
|
||||
body = await self._http_client.get_raw(uri, args)
|
||||
except PartialDownloadError as pde:
|
||||
# Twisted raises this error if the connection is closed,
|
||||
# even if that's being used old-http style to signal end-of-data
|
||||
body = pde.response
|
||||
result = await self.handle_cas_response(request, body, client_redirect_url)
|
||||
return result
|
||||
|
||||
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
|
||||
user, attributes = self.parse_cas_response(cas_response_body)
|
||||
displayname = attributes.pop(self.cas_displayname_attribute, None)
|
||||
|
||||
for required_attribute, required_value in self.cas_required_attributes.items():
|
||||
# If required attribute was not in CAS Response - Forbidden
|
||||
if required_attribute not in attributes:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
# Also need to check value
|
||||
if required_value is not None:
|
||||
actual_value = attributes[required_attribute]
|
||||
# If required attribute value does not match expected - Forbidden
|
||||
if required_value != actual_value:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
return self._sso_auth_handler.on_successful_auth(
|
||||
user, request, client_redirect_url, displayname
|
||||
ticket = parse_string(request, "ticket", required=True)
|
||||
await self._cas_handler.handle_ticket_request(
|
||||
request, client_redirect_url, ticket
|
||||
)
|
||||
|
||||
def parse_cas_response(self, cas_response_body):
|
||||
user = None
|
||||
attributes = {}
|
||||
try:
|
||||
root = ET.fromstring(cas_response_body)
|
||||
if not root.tag.endswith("serviceResponse"):
|
||||
raise Exception("root of CAS response is not serviceResponse")
|
||||
success = root[0].tag.endswith("authenticationSuccess")
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
if child.tag.endswith("attributes"):
|
||||
for attribute in child:
|
||||
# ElementTree library expands the namespace in
|
||||
# attribute tags to the full URL of the namespace.
|
||||
# We don't care about namespace here and it will always
|
||||
# be encased in curly braces, so we remove them.
|
||||
tag = attribute.tag
|
||||
if "}" in tag:
|
||||
tag = tag.split("}")[1]
|
||||
attributes[tag] = attribute.text
|
||||
if user is None:
|
||||
raise Exception("CAS response does not contain user")
|
||||
except Exception:
|
||||
logger.exception("Error parsing CAS response")
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
if not success:
|
||||
raise LoginError(
|
||||
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
return user, attributes
|
||||
|
||||
|
||||
class SAMLRedirectServlet(BaseSSORedirectServlet):
|
||||
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
|
||||
|
@ -533,65 +449,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
|
|||
def __init__(self, hs):
|
||||
self._saml_handler = hs.get_saml_handler()
|
||||
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
|
||||
return self._saml_handler.handle_redirect_request(client_redirect_url)
|
||||
|
||||
|
||||
class SSOAuthHandler(object):
|
||||
"""
|
||||
Utility class for Resources and Servlets which handle the response from a SSO
|
||||
service
|
||||
|
||||
Args:
|
||||
hs (synapse.server.HomeServer)
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self._hostname = hs.hostname
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
self._macaroon_gen = hs.get_macaroon_generator()
|
||||
|
||||
# cast to tuple for use with str.startswith
|
||||
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
|
||||
|
||||
async def on_successful_auth(
|
||||
self, username, request, client_redirect_url, user_display_name=None
|
||||
):
|
||||
"""Called once the user has successfully authenticated with the SSO.
|
||||
|
||||
Registers the user if necessary, and then returns a redirect (with
|
||||
a login token) to the client.
|
||||
|
||||
Args:
|
||||
username (unicode|bytes): the remote user id. We'll map this onto
|
||||
something sane for a MXID localpath.
|
||||
|
||||
request (SynapseRequest): the incoming request from the browser. We'll
|
||||
respond to it with a redirect.
|
||||
|
||||
client_redirect_url (unicode): the redirect_url the client gave us when
|
||||
it first started the process.
|
||||
|
||||
user_display_name (unicode|None): if set, and we have to register a new user,
|
||||
we will set their displayname to this.
|
||||
|
||||
Returns:
|
||||
Deferred[none]: Completes once we have handled the request.
|
||||
"""
|
||||
localpart = map_username_to_mxid_localpart(username)
|
||||
user_id = UserID(localpart, self._hostname).to_string()
|
||||
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
||||
if not registered_user_id:
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart, default_display_name=user_display_name
|
||||
)
|
||||
|
||||
self._auth_handler.complete_sso_login(
|
||||
registered_user_id, request, client_redirect_url
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
LoginRestServlet(hs).register(http_server)
|
||||
if hs.config.cas_enabled:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue