Refactor the CAS code (move the logic out of the REST layer to a handler) (#7136)

This commit is contained in:
Patrick Cloke 2020-03-26 15:05:26 -04:00 committed by GitHub
parent 825fb5d0a5
commit fa4f12102d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 226 additions and 154 deletions

View file

@ -14,11 +14,6 @@
# limitations under the License.
import logging
import xml.etree.ElementTree as ET
from six.moves import urllib
from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@ -28,9 +23,10 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
logger = logging.getLogger(__name__)
@ -72,14 +68,6 @@ def login_id_thirdparty_from_phone(identifier):
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
def build_service_param(cas_service_url, client_redirect_url):
return "%s%s?redirectUrl=%s" % (
cas_service_url,
"/_matrix/client/r0/login/cas/ticket",
urllib.parse.quote(client_redirect_url, safe=""),
)
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
@ -409,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def on_GET(self, request):
def on_GET(self, request: SynapseRequest):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
@ -418,15 +406,15 @@ class BaseSSORedirectServlet(RestServlet):
request.redirect(sso_url)
finish_request(request)
def get_sso_url(self, client_redirect_url):
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
client_redirect_url (bytes): the URL that we should redirect the
client_redirect_url: the URL that we should redirect the
client to when everything is done
Returns:
bytes: URL to redirect to
URL to redirect to
"""
# to be implemented by subclasses
raise NotImplementedError()
@ -434,16 +422,10 @@ class BaseSSORedirectServlet(RestServlet):
class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self._cas_handler = hs.get_cas_handler()
def get_sso_url(self, client_redirect_url):
args = urllib.parse.urlencode(
{"service": build_service_param(self.cas_service_url, client_redirect_url)}
)
return "%s/login?%s" % (self.cas_server_url, args)
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
return self._cas_handler.handle_redirect_request(client_redirect_url)
class CasTicketServlet(RestServlet):
@ -451,81 +433,15 @@ class CasTicketServlet(RestServlet):
def __init__(self, hs):
super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_proxied_http_client()
self._cas_handler = hs.get_cas_handler()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> None:
client_redirect_url = parse_string(request, "redirectUrl", required=True)
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
"service": build_service_param(self.cas_service_url, client_redirect_url),
}
try:
body = await self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
result = await self.handle_cas_response(request, body, client_redirect_url)
return result
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
displayname = attributes.pop(self.cas_displayname_attribute, None)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return self._sso_auth_handler.on_successful_auth(
user, request, client_redirect_url, displayname
ticket = parse_string(request, "ticket", required=True)
await self._cas_handler.handle_ticket_request(
request, client_redirect_url, ticket
)
def parse_cas_response(self, cas_response_body):
user = None
attributes = {}
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
class SAMLRedirectServlet(BaseSSORedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@ -533,65 +449,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
def get_sso_url(self, client_redirect_url):
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
class SSOAuthHandler(object):
"""
Utility class for Resources and Servlets which handle the response from a SSO
service
Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator()
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
async def on_successful_auth(
self, username, request, client_redirect_url, user_display_name=None
):
"""Called once the user has successfully authenticated with the SSO.
Registers the user if necessary, and then returns a redirect (with
a login token) to the client.
Args:
username (unicode|bytes): the remote user id. We'll map this onto
something sane for a MXID localpath.
request (SynapseRequest): the incoming request from the browser. We'll
respond to it with a redirect.
client_redirect_url (unicode): the redirect_url the client gave us when
it first started the process.
user_display_name (unicode|None): if set, and we have to register a new user,
we will set their displayname to this.
Returns:
Deferred[none]: Completes once we have handled the request.
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name
)
self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url
)
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.cas_enabled: