mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Simplify the flow for SSO UIA (#8881)
* SsoHandler: remove inheritance from BaseHandler * Simplify the flow for SSO UIA We don't need to do all the magic for mapping users when we are doing UIA, so let's factor that out.
This commit is contained in:
parent
025fa06fc7
commit
36ba73f53d
1
changelog.d/8881.misc
Normal file
1
changelog.d/8881.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Simplify logic for handling user-interactive-auth via single-sign-on servers.
|
1
mypy.ini
1
mypy.ini
@ -43,6 +43,7 @@ files =
|
|||||||
synapse/handlers/room_member.py,
|
synapse/handlers/room_member.py,
|
||||||
synapse/handlers/room_member_worker.py,
|
synapse/handlers/room_member_worker.py,
|
||||||
synapse/handlers/saml_handler.py,
|
synapse/handlers/saml_handler.py,
|
||||||
|
synapse/handlers/sso.py,
|
||||||
synapse/handlers/sync.py,
|
synapse/handlers/sync.py,
|
||||||
synapse/handlers/ui_auth,
|
synapse/handlers/ui_auth,
|
||||||
synapse/http/client.py,
|
synapse/http/client.py,
|
||||||
|
@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
|
|||||||
class BaseHandler:
|
class BaseHandler:
|
||||||
"""
|
"""
|
||||||
Common base class for the event handlers.
|
Common base class for the event handlers.
|
||||||
|
|
||||||
|
Deprecated: new code should not use this. Instead, Handler classes should define the
|
||||||
|
fields they actually need. The utility methods should either be factored out to
|
||||||
|
standalone helper functions, or to different Handler classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
@ -36,6 +36,8 @@ import attr
|
|||||||
import bcrypt
|
import bcrypt
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
|
from twisted.web.http import Request
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
@ -1331,15 +1333,14 @@ class AuthHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def complete_sso_ui_auth(
|
async def complete_sso_ui_auth(
|
||||||
self, registered_user_id: str, session_id: str, request: SynapseRequest,
|
self, registered_user_id: str, session_id: str, request: Request,
|
||||||
):
|
):
|
||||||
"""Having figured out a mxid for this user, complete the HTTP request
|
"""Having figured out a mxid for this user, complete the HTTP request
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
registered_user_id: The registered user ID to complete SSO login for.
|
registered_user_id: The registered user ID to complete SSO login for.
|
||||||
|
session_id: The ID of the user-interactive auth session.
|
||||||
request: The request to complete.
|
request: The request to complete.
|
||||||
client_redirect_url: The URL to which to redirect the user at the end of the
|
|
||||||
process.
|
|
||||||
"""
|
"""
|
||||||
# Mark the stage of the authentication as successful.
|
# Mark the stage of the authentication as successful.
|
||||||
# Save the user who authenticated with SSO, this will be used to ensure
|
# Save the user who authenticated with SSO, this will be used to ensure
|
||||||
@ -1355,7 +1356,7 @@ class AuthHandler(BaseHandler):
|
|||||||
async def complete_sso_login(
|
async def complete_sso_login(
|
||||||
self,
|
self,
|
||||||
registered_user_id: str,
|
registered_user_id: str,
|
||||||
request: SynapseRequest,
|
request: Request,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
extra_attributes: Optional[JsonDict] = None,
|
extra_attributes: Optional[JsonDict] = None,
|
||||||
):
|
):
|
||||||
@ -1383,7 +1384,7 @@ class AuthHandler(BaseHandler):
|
|||||||
def _complete_sso_login(
|
def _complete_sso_login(
|
||||||
self,
|
self,
|
||||||
registered_user_id: str,
|
registered_user_id: str,
|
||||||
request: SynapseRequest,
|
request: Request,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
extra_attributes: Optional[JsonDict] = None,
|
extra_attributes: Optional[JsonDict] = None,
|
||||||
):
|
):
|
||||||
|
@ -674,6 +674,21 @@ class OidcHandler(BaseHandler):
|
|||||||
self._sso_handler.render_error(request, "invalid_token", str(e))
|
self._sso_handler.render_error(request, "invalid_token", str(e))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# first check if we're doing a UIA
|
||||||
|
if ui_auth_session_id:
|
||||||
|
try:
|
||||||
|
remote_user_id = self._remote_id_from_userinfo(userinfo)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Could not extract remote user id")
|
||||||
|
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||||
|
return
|
||||||
|
|
||||||
|
return await self._sso_handler.complete_sso_ui_auth_request(
|
||||||
|
self._auth_provider_id, remote_user_id, ui_auth_session_id, request
|
||||||
|
)
|
||||||
|
|
||||||
|
# otherwise, it's a login
|
||||||
|
|
||||||
# Pull out the user-agent and IP from the request.
|
# Pull out the user-agent and IP from the request.
|
||||||
user_agent = request.get_user_agent("")
|
user_agent = request.get_user_agent("")
|
||||||
ip_address = self.hs.get_ip_from_request(request)
|
ip_address = self.hs.get_ip_from_request(request)
|
||||||
@ -698,14 +713,9 @@ class OidcHandler(BaseHandler):
|
|||||||
extra_attributes = await get_extra_attributes(userinfo, token)
|
extra_attributes = await get_extra_attributes(userinfo, token)
|
||||||
|
|
||||||
# and finally complete the login
|
# and finally complete the login
|
||||||
if ui_auth_session_id:
|
await self._auth_handler.complete_sso_login(
|
||||||
await self._auth_handler.complete_sso_ui_auth(
|
user_id, request, client_redirect_url, extra_attributes
|
||||||
user_id, ui_auth_session_id, request
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
await self._auth_handler.complete_sso_login(
|
|
||||||
user_id, request, client_redirect_url, extra_attributes
|
|
||||||
)
|
|
||||||
|
|
||||||
def _generate_oidc_session_token(
|
def _generate_oidc_session_token(
|
||||||
self,
|
self,
|
||||||
@ -856,14 +866,11 @@ class OidcHandler(BaseHandler):
|
|||||||
The mxid of the user
|
The mxid of the user
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
|
remote_user_id = self._remote_id_from_userinfo(userinfo)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise MappingException(
|
raise MappingException(
|
||||||
"Failed to extract subject from OIDC response: %s" % (e,)
|
"Failed to extract subject from OIDC response: %s" % (e,)
|
||||||
)
|
)
|
||||||
# Some OIDC providers use integer IDs, but Synapse expects external IDs
|
|
||||||
# to be strings.
|
|
||||||
remote_user_id = str(remote_user_id)
|
|
||||||
|
|
||||||
# Older mapping providers don't accept the `failures` argument, so we
|
# Older mapping providers don't accept the `failures` argument, so we
|
||||||
# try and detect support.
|
# try and detect support.
|
||||||
@ -933,6 +940,19 @@ class OidcHandler(BaseHandler):
|
|||||||
grandfather_existing_users,
|
grandfather_existing_users,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
|
||||||
|
"""Extract the unique remote id from an OIDC UserInfo block
|
||||||
|
|
||||||
|
Args:
|
||||||
|
userinfo: An object representing the user given by the OIDC provider
|
||||||
|
Returns:
|
||||||
|
remote user id
|
||||||
|
"""
|
||||||
|
remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
|
||||||
|
# Some OIDC providers use integer IDs, but Synapse expects external IDs
|
||||||
|
# to be strings.
|
||||||
|
return str(remote_user_id)
|
||||||
|
|
||||||
|
|
||||||
UserAttributeDict = TypedDict(
|
UserAttributeDict = TypedDict(
|
||||||
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
|
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
|
||||||
|
@ -183,6 +183,24 @@ class SamlHandler(BaseHandler):
|
|||||||
saml2_auth.in_response_to, None
|
saml2_auth.in_response_to, None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# first check if we're doing a UIA
|
||||||
|
if current_session and current_session.ui_auth_session_id:
|
||||||
|
try:
|
||||||
|
remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
|
||||||
|
except MappingException as e:
|
||||||
|
logger.exception("Failed to extract remote user id from SAML response")
|
||||||
|
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||||
|
return
|
||||||
|
|
||||||
|
return await self._sso_handler.complete_sso_ui_auth_request(
|
||||||
|
self._auth_provider_id,
|
||||||
|
remote_user_id,
|
||||||
|
current_session.ui_auth_session_id,
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
|
||||||
|
# otherwise, we're handling a login request.
|
||||||
|
|
||||||
# Ensure that the attributes of the logged in user meet the required
|
# Ensure that the attributes of the logged in user meet the required
|
||||||
# attributes.
|
# attributes.
|
||||||
for requirement in self._saml2_attribute_requirements:
|
for requirement in self._saml2_attribute_requirements:
|
||||||
@ -206,14 +224,7 @@ class SamlHandler(BaseHandler):
|
|||||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Complete the interactive auth session or the login.
|
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
||||||
if current_session and current_session.ui_auth_session_id:
|
|
||||||
await self._auth_handler.complete_sso_ui_auth(
|
|
||||||
user_id, current_session.ui_auth_session_id, request
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
|
||||||
|
|
||||||
async def _map_saml_response_to_user(
|
async def _map_saml_response_to_user(
|
||||||
self,
|
self,
|
||||||
@ -239,16 +250,10 @@ class SamlHandler(BaseHandler):
|
|||||||
RedirectException: some mapping providers may raise this if they need
|
RedirectException: some mapping providers may raise this if they need
|
||||||
to redirect to an interstitial page.
|
to redirect to an interstitial page.
|
||||||
"""
|
"""
|
||||||
|
remote_user_id = self._remote_id_from_saml_response(
|
||||||
remote_user_id = self._user_mapping_provider.get_remote_user_id(
|
|
||||||
saml2_auth, client_redirect_url
|
saml2_auth, client_redirect_url
|
||||||
)
|
)
|
||||||
|
|
||||||
if not remote_user_id:
|
|
||||||
raise MappingException(
|
|
||||||
"Failed to extract remote user id from SAML response"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def saml_response_to_remapped_user_attributes(
|
async def saml_response_to_remapped_user_attributes(
|
||||||
failures: int,
|
failures: int,
|
||||||
) -> UserAttributes:
|
) -> UserAttributes:
|
||||||
@ -304,6 +309,35 @@ class SamlHandler(BaseHandler):
|
|||||||
grandfather_existing_users,
|
grandfather_existing_users,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _remote_id_from_saml_response(
|
||||||
|
self,
|
||||||
|
saml2_auth: saml2.response.AuthnResponse,
|
||||||
|
client_redirect_url: Optional[str],
|
||||||
|
) -> str:
|
||||||
|
"""Extract the unique remote id from a SAML2 AuthnResponse
|
||||||
|
|
||||||
|
Args:
|
||||||
|
saml2_auth: The parsed SAML2 response.
|
||||||
|
client_redirect_url: The redirect URL passed in by the client.
|
||||||
|
Returns:
|
||||||
|
remote user id
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MappingException if there was an error extracting the user id
|
||||||
|
"""
|
||||||
|
# It's not obvious why we need to pass in the redirect URI to the mapping
|
||||||
|
# provider, but we do :/
|
||||||
|
remote_user_id = self._user_mapping_provider.get_remote_user_id(
|
||||||
|
saml2_auth, client_redirect_url
|
||||||
|
)
|
||||||
|
|
||||||
|
if not remote_user_id:
|
||||||
|
raise MappingException(
|
||||||
|
"Failed to extract remote user id from SAML response"
|
||||||
|
)
|
||||||
|
|
||||||
|
return remote_user_id
|
||||||
|
|
||||||
def expire_sessions(self):
|
def expire_sessions(self):
|
||||||
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
||||||
to_expire = set()
|
to_expire = set()
|
||||||
|
@ -17,8 +17,9 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
|
|||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
from twisted.web.http import Request
|
||||||
|
|
||||||
from synapse.api.errors import RedirectException
|
from synapse.api.errors import RedirectException
|
||||||
from synapse.handlers._base import BaseHandler
|
|
||||||
from synapse.http.server import respond_with_html
|
from synapse.http.server import respond_with_html
|
||||||
from synapse.types import UserID, contains_invalid_mxid_characters
|
from synapse.types import UserID, contains_invalid_mxid_characters
|
||||||
|
|
||||||
@ -42,14 +43,16 @@ class UserAttributes:
|
|||||||
emails = attr.ib(type=List[str], default=attr.Factory(list))
|
emails = attr.ib(type=List[str], default=attr.Factory(list))
|
||||||
|
|
||||||
|
|
||||||
class SsoHandler(BaseHandler):
|
class SsoHandler:
|
||||||
# The number of attempts to ask the mapping provider for when generating an MXID.
|
# The number of attempts to ask the mapping provider for when generating an MXID.
|
||||||
_MAP_USERNAME_RETRIES = 1000
|
_MAP_USERNAME_RETRIES = 1000
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
self._store = hs.get_datastore()
|
||||||
|
self._server_name = hs.hostname
|
||||||
self._registration_handler = hs.get_registration_handler()
|
self._registration_handler = hs.get_registration_handler()
|
||||||
self._error_template = hs.config.sso_error_template
|
self._error_template = hs.config.sso_error_template
|
||||||
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
def render_error(
|
def render_error(
|
||||||
self, request, error: str, error_description: Optional[str] = None
|
self, request, error: str, error_description: Optional[str] = None
|
||||||
@ -95,7 +98,7 @@ class SsoHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if we already have a mapping for this user.
|
# Check if we already have a mapping for this user.
|
||||||
previously_registered_user_id = await self.store.get_user_by_external_id(
|
previously_registered_user_id = await self._store.get_user_by_external_id(
|
||||||
auth_provider_id, remote_user_id,
|
auth_provider_id, remote_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -181,7 +184,7 @@ class SsoHandler(BaseHandler):
|
|||||||
previously_registered_user_id = await grandfather_existing_users()
|
previously_registered_user_id = await grandfather_existing_users()
|
||||||
if previously_registered_user_id:
|
if previously_registered_user_id:
|
||||||
# Future logins should also match this user ID.
|
# Future logins should also match this user ID.
|
||||||
await self.store.record_user_external_id(
|
await self._store.record_user_external_id(
|
||||||
auth_provider_id, remote_user_id, previously_registered_user_id
|
auth_provider_id, remote_user_id, previously_registered_user_id
|
||||||
)
|
)
|
||||||
return previously_registered_user_id
|
return previously_registered_user_id
|
||||||
@ -214,8 +217,8 @@ class SsoHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if this mxid already exists
|
# Check if this mxid already exists
|
||||||
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
user_id = UserID(attributes.localpart, self._server_name).to_string()
|
||||||
if not await self.store.get_users_by_id_case_insensitive(user_id):
|
if not await self._store.get_users_by_id_case_insensitive(user_id):
|
||||||
# This mxid is free
|
# This mxid is free
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@ -238,7 +241,47 @@ class SsoHandler(BaseHandler):
|
|||||||
user_agent_ips=[(user_agent, ip_address)],
|
user_agent_ips=[(user_agent, ip_address)],
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.store.record_user_external_id(
|
await self._store.record_user_external_id(
|
||||||
auth_provider_id, remote_user_id, registered_user_id
|
auth_provider_id, remote_user_id, registered_user_id
|
||||||
)
|
)
|
||||||
return registered_user_id
|
return registered_user_id
|
||||||
|
|
||||||
|
async def complete_sso_ui_auth_request(
|
||||||
|
self,
|
||||||
|
auth_provider_id: str,
|
||||||
|
remote_user_id: str,
|
||||||
|
ui_auth_session_id: str,
|
||||||
|
request: Request,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Given an SSO ID, retrieve the user ID for it and complete UIA.
|
||||||
|
|
||||||
|
Note that this requires that the user is mapped in the "user_external_ids"
|
||||||
|
table. This will be the case if they have ever logged in via SAML or OIDC in
|
||||||
|
recentish synapse versions, but may not be for older users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||||
|
"oidc" or "saml".
|
||||||
|
remote_user_id: The unique identifier from the SSO provider.
|
||||||
|
ui_auth_session_id: The ID of the user-interactive auth session.
|
||||||
|
request: The request to complete.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id = await self.get_sso_user_by_remote_user_id(
|
||||||
|
auth_provider_id, remote_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
logger.warning(
|
||||||
|
"Remote user %s/%s has not previously logged in here: UIA will fail",
|
||||||
|
auth_provider_id,
|
||||||
|
remote_user_id,
|
||||||
|
)
|
||||||
|
# Let the UIA flow handle this the same as if they presented creds for a
|
||||||
|
# different user.
|
||||||
|
user_id = ""
|
||||||
|
|
||||||
|
await self._auth_handler.complete_sso_ui_auth(
|
||||||
|
user_id, ui_auth_session_id, request
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user