mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-04 03:14:48 -04:00
Abstract shared SSO code. (#8765)
De-duplicates code between the SAML and OIDC implementations.
This commit is contained in:
parent
e487d9fabc
commit
ee382025b0
6 changed files with 159 additions and 120 deletions
|
@ -34,7 +34,8 @@ from typing_extensions import TypedDict
|
|||
from twisted.web.client import readBody
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.http.server import respond_with_html
|
||||
from synapse.handlers._base import BaseHandler
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||
|
@ -83,17 +84,12 @@ class OidcError(Exception):
|
|||
return self.error
|
||||
|
||||
|
||||
class MappingException(Exception):
|
||||
"""Used to catch errors when mapping the UserInfo object
|
||||
"""
|
||||
|
||||
|
||||
class OidcHandler:
|
||||
class OidcHandler(BaseHandler):
|
||||
"""Handles requests related to the OpenID Connect login flow.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
super().__init__(hs)
|
||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||
self._scopes = hs.config.oidc_scopes # type: List[str]
|
||||
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
|
||||
|
@ -120,36 +116,13 @@ class OidcHandler:
|
|||
self._http_client = hs.get_proxied_http_client()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
self._datastore = hs.get_datastore()
|
||||
self._clock = hs.get_clock()
|
||||
self._hostname = hs.hostname # type: str
|
||||
self._server_name = hs.config.server_name # type: str
|
||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||
self._error_template = hs.config.sso_error_template
|
||||
|
||||
# identifier for the external_ids table
|
||||
self._auth_provider_id = "oidc"
|
||||
|
||||
def _render_error(
|
||||
self, request, error: str, error_description: Optional[str] = None
|
||||
) -> None:
|
||||
"""Render the error template and respond to the request with it.
|
||||
|
||||
This is used to show errors to the user. The template of this page can
|
||||
be found under `synapse/res/templates/sso_error.html`.
|
||||
|
||||
Args:
|
||||
request: The incoming request from the browser.
|
||||
We'll respond with an HTML page describing the error.
|
||||
error: A technical identifier for this error. Those include
|
||||
well-known OAuth2/OIDC error types like invalid_request or
|
||||
access_denied.
|
||||
error_description: A human-readable description of the error.
|
||||
"""
|
||||
html = self._error_template.render(
|
||||
error=error, error_description=error_description
|
||||
)
|
||||
respond_with_html(request, 400, html)
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
|
||||
def _validate_metadata(self):
|
||||
"""Verifies the provider metadata.
|
||||
|
@ -571,7 +544,7 @@ class OidcHandler:
|
|||
|
||||
Since we might want to display OIDC-related errors in a user-friendly
|
||||
way, we don't raise SynapseError from here. Instead, we call
|
||||
``self._render_error`` which displays an HTML page for the error.
|
||||
``self._sso_handler.render_error`` which displays an HTML page for the error.
|
||||
|
||||
Most of the OpenID Connect logic happens here:
|
||||
|
||||
|
@ -609,7 +582,7 @@ class OidcHandler:
|
|||
if error != "access_denied":
|
||||
logger.error("Error from the OIDC provider: %s %s", error, description)
|
||||
|
||||
self._render_error(request, error, description)
|
||||
self._sso_handler.render_error(request, error, description)
|
||||
return
|
||||
|
||||
# otherwise, it is presumably a successful response. see:
|
||||
|
@ -619,7 +592,9 @@ class OidcHandler:
|
|||
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
|
||||
if session is None:
|
||||
logger.info("No session cookie found")
|
||||
self._render_error(request, "missing_session", "No session cookie found")
|
||||
self._sso_handler.render_error(
|
||||
request, "missing_session", "No session cookie found"
|
||||
)
|
||||
return
|
||||
|
||||
# Remove the cookie. There is a good chance that if the callback failed
|
||||
|
@ -637,7 +612,9 @@ class OidcHandler:
|
|||
# Check for the state query parameter
|
||||
if b"state" not in request.args:
|
||||
logger.info("State parameter is missing")
|
||||
self._render_error(request, "invalid_request", "State parameter is missing")
|
||||
self._sso_handler.render_error(
|
||||
request, "invalid_request", "State parameter is missing"
|
||||
)
|
||||
return
|
||||
|
||||
state = request.args[b"state"][0].decode()
|
||||
|
@ -651,17 +628,19 @@ class OidcHandler:
|
|||
) = self._verify_oidc_session_token(session, state)
|
||||
except MacaroonDeserializationException as e:
|
||||
logger.exception("Invalid session")
|
||||
self._render_error(request, "invalid_session", str(e))
|
||||
self._sso_handler.render_error(request, "invalid_session", str(e))
|
||||
return
|
||||
except MacaroonInvalidSignatureException as e:
|
||||
logger.exception("Could not verify session")
|
||||
self._render_error(request, "mismatching_session", str(e))
|
||||
self._sso_handler.render_error(request, "mismatching_session", str(e))
|
||||
return
|
||||
|
||||
# Exchange the code with the provider
|
||||
if b"code" not in request.args:
|
||||
logger.info("Code parameter is missing")
|
||||
self._render_error(request, "invalid_request", "Code parameter is missing")
|
||||
self._sso_handler.render_error(
|
||||
request, "invalid_request", "Code parameter is missing"
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug("Exchanging code")
|
||||
|
@ -670,7 +649,7 @@ class OidcHandler:
|
|||
token = await self._exchange_code(code)
|
||||
except OidcError as e:
|
||||
logger.exception("Could not exchange code")
|
||||
self._render_error(request, e.error, e.error_description)
|
||||
self._sso_handler.render_error(request, e.error, e.error_description)
|
||||
return
|
||||
|
||||
logger.debug("Successfully obtained OAuth2 access token")
|
||||
|
@ -683,7 +662,7 @@ class OidcHandler:
|
|||
userinfo = await self._fetch_userinfo(token)
|
||||
except Exception as e:
|
||||
logger.exception("Could not fetch userinfo")
|
||||
self._render_error(request, "fetch_error", str(e))
|
||||
self._sso_handler.render_error(request, "fetch_error", str(e))
|
||||
return
|
||||
else:
|
||||
logger.debug("Extracting userinfo from id_token")
|
||||
|
@ -691,7 +670,7 @@ class OidcHandler:
|
|||
userinfo = await self._parse_id_token(token, nonce=nonce)
|
||||
except Exception as e:
|
||||
logger.exception("Invalid id_token")
|
||||
self._render_error(request, "invalid_token", str(e))
|
||||
self._sso_handler.render_error(request, "invalid_token", str(e))
|
||||
return
|
||||
|
||||
# Pull out the user-agent and IP from the request.
|
||||
|
@ -705,7 +684,7 @@ class OidcHandler:
|
|||
)
|
||||
except MappingException as e:
|
||||
logger.exception("Could not map user")
|
||||
self._render_error(request, "mapping_error", str(e))
|
||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||
return
|
||||
|
||||
# Mapping providers might not have get_extra_attributes: only call this
|
||||
|
@ -770,7 +749,7 @@ class OidcHandler:
|
|||
macaroon.add_first_party_caveat(
|
||||
"ui_auth_session_id = %s" % (ui_auth_session_id,)
|
||||
)
|
||||
now = self._clock.time_msec()
|
||||
now = self.clock.time_msec()
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
|
||||
|
@ -845,7 +824,7 @@ class OidcHandler:
|
|||
if not caveat.startswith(prefix):
|
||||
return False
|
||||
expiry = int(caveat[len(prefix) :])
|
||||
now = self._clock.time_msec()
|
||||
now = self.clock.time_msec()
|
||||
return now < expiry
|
||||
|
||||
async def _map_userinfo_to_user(
|
||||
|
@ -885,20 +864,14 @@ class OidcHandler:
|
|||
# to be strings.
|
||||
remote_user_id = str(remote_user_id)
|
||||
|
||||
logger.info(
|
||||
"Looking for existing mapping for user %s:%s",
|
||||
self._auth_provider_id,
|
||||
remote_user_id,
|
||||
)
|
||||
|
||||
registered_user_id = await self._datastore.get_user_by_external_id(
|
||||
# first of all, check if we already have a mapping for this user
|
||||
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
|
||||
self._auth_provider_id, remote_user_id,
|
||||
)
|
||||
if previously_registered_user_id:
|
||||
return previously_registered_user_id
|
||||
|
||||
if registered_user_id is not None:
|
||||
logger.info("Found existing mapping %s", registered_user_id)
|
||||
return registered_user_id
|
||||
|
||||
# Otherwise, generate a new user.
|
||||
try:
|
||||
attributes = await self._user_mapping_provider.map_user_attributes(
|
||||
userinfo, token
|
||||
|
@ -917,8 +890,8 @@ class OidcHandler:
|
|||
|
||||
localpart = map_username_to_mxid_localpart(attributes["localpart"])
|
||||
|
||||
user_id = UserID(localpart, self._hostname).to_string()
|
||||
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
|
||||
user_id = UserID(localpart, self.server_name).to_string()
|
||||
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||
if users:
|
||||
if self._allow_existing_users:
|
||||
if len(users) == 1:
|
||||
|
@ -942,7 +915,8 @@ class OidcHandler:
|
|||
default_display_name=attributes["display_name"],
|
||||
user_agent_ips=(user_agent, ip_address),
|
||||
)
|
||||
await self._datastore.record_user_external_id(
|
||||
|
||||
await self.store.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id,
|
||||
)
|
||||
return registered_user_id
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue