mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-06 11:55:04 -04:00
Abstract shared SSO code. (#8765)
De-duplicates code between the SAML and OIDC implementations.
This commit is contained in:
parent
e487d9fabc
commit
ee382025b0
6 changed files with 159 additions and 120 deletions
|
@ -24,7 +24,8 @@ from saml2.client import Saml2Client
|
|||
from synapse.api.errors import SynapseError
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.saml2_config import SamlAttributeRequirement
|
||||
from synapse.http.server import respond_with_html
|
||||
from synapse.handlers._base import BaseHandler
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.module_api import ModuleApi
|
||||
|
@ -42,10 +43,6 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MappingException(Exception):
|
||||
"""Used to catch errors when mapping the SAML2 response to a user."""
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class Saml2SessionData:
|
||||
"""Data we track about SAML2 sessions"""
|
||||
|
@ -57,17 +54,13 @@ class Saml2SessionData:
|
|||
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
|
||||
|
||||
|
||||
class SamlHandler:
|
||||
class SamlHandler(BaseHandler):
|
||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||
self.hs = hs
|
||||
super().__init__(hs)
|
||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||
self._auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
self._datastore = hs.get_datastore()
|
||||
self._hostname = hs.hostname
|
||||
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
|
||||
self._grandfathered_mxid_source_attribute = (
|
||||
hs.config.saml2_grandfathered_mxid_source_attribute
|
||||
|
@ -88,26 +81,9 @@ class SamlHandler:
|
|||
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
||||
|
||||
# a lock on the mappings
|
||||
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
|
||||
self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
|
||||
|
||||
def _render_error(
|
||||
self, request, error: str, error_description: Optional[str] = None
|
||||
) -> None:
|
||||
"""Render the error template and respond to the request with it.
|
||||
|
||||
This is used to show errors to the user. The template of this page can
|
||||
be found under `synapse/res/templates/sso_error.html`.
|
||||
|
||||
Args:
|
||||
request: The incoming request from the browser.
|
||||
We'll respond with an HTML page describing the error.
|
||||
error: A technical identifier for this error.
|
||||
error_description: A human-readable description of the error.
|
||||
"""
|
||||
html = self._error_template.render(
|
||||
error=error, error_description=error_description
|
||||
)
|
||||
respond_with_html(request, 400, html)
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
|
||||
def handle_redirect_request(
|
||||
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
|
||||
|
@ -130,7 +106,7 @@ class SamlHandler:
|
|||
# Since SAML sessions timeout it is useful to log when they were created.
|
||||
logger.info("Initiating a new SAML session: %s" % (reqid,))
|
||||
|
||||
now = self._clock.time_msec()
|
||||
now = self.clock.time_msec()
|
||||
self._outstanding_requests_dict[reqid] = Saml2SessionData(
|
||||
creation_time=now, ui_auth_session_id=ui_auth_session_id,
|
||||
)
|
||||
|
@ -171,12 +147,12 @@ class SamlHandler:
|
|||
# in the (user-visible) exception message, so let's log the exception here
|
||||
# so we can track down the session IDs later.
|
||||
logger.warning(str(e))
|
||||
self._render_error(
|
||||
self._sso_handler.render_error(
|
||||
request, "unsolicited_response", "Unexpected SAML2 login."
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
self._render_error(
|
||||
self._sso_handler.render_error(
|
||||
request,
|
||||
"invalid_response",
|
||||
"Unable to parse SAML2 response: %s." % (e,),
|
||||
|
@ -184,7 +160,7 @@ class SamlHandler:
|
|||
return
|
||||
|
||||
if saml2_auth.not_signed:
|
||||
self._render_error(
|
||||
self._sso_handler.render_error(
|
||||
request, "unsigned_respond", "SAML2 response was not signed."
|
||||
)
|
||||
return
|
||||
|
@ -210,7 +186,7 @@ class SamlHandler:
|
|||
# attributes.
|
||||
for requirement in self._saml2_attribute_requirements:
|
||||
if not _check_attribute_requirement(saml2_auth.ava, requirement):
|
||||
self._render_error(
|
||||
self._sso_handler.render_error(
|
||||
request, "unauthorised", "You are not authorised to log in here."
|
||||
)
|
||||
return
|
||||
|
@ -226,7 +202,7 @@ class SamlHandler:
|
|||
)
|
||||
except MappingException as e:
|
||||
logger.exception("Could not map user")
|
||||
self._render_error(request, "mapping_error", str(e))
|
||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||
return
|
||||
|
||||
# Complete the interactive auth session or the login.
|
||||
|
@ -274,17 +250,11 @@ class SamlHandler:
|
|||
|
||||
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
||||
# first of all, check if we already have a mapping for this user
|
||||
logger.info(
|
||||
"Looking for existing mapping for user %s:%s",
|
||||
self._auth_provider_id,
|
||||
remote_user_id,
|
||||
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
|
||||
self._auth_provider_id, remote_user_id,
|
||||
)
|
||||
registered_user_id = await self._datastore.get_user_by_external_id(
|
||||
self._auth_provider_id, remote_user_id
|
||||
)
|
||||
if registered_user_id is not None:
|
||||
logger.info("Found existing mapping %s", registered_user_id)
|
||||
return registered_user_id
|
||||
if previously_registered_user_id:
|
||||
return previously_registered_user_id
|
||||
|
||||
# backwards-compatibility hack: see if there is an existing user with a
|
||||
# suitable mapping from the uid
|
||||
|
@ -294,7 +264,7 @@ class SamlHandler:
|
|||
):
|
||||
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
|
||||
user_id = UserID(
|
||||
map_username_to_mxid_localpart(attrval), self._hostname
|
||||
map_username_to_mxid_localpart(attrval), self.server_name
|
||||
).to_string()
|
||||
logger.info(
|
||||
"Looking for existing account based on mapped %s %s",
|
||||
|
@ -302,11 +272,11 @@ class SamlHandler:
|
|||
user_id,
|
||||
)
|
||||
|
||||
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
|
||||
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||
if users:
|
||||
registered_user_id = list(users.keys())[0]
|
||||
logger.info("Grandfathering mapping to %s", registered_user_id)
|
||||
await self._datastore.record_user_external_id(
|
||||
await self.store.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id
|
||||
)
|
||||
return registered_user_id
|
||||
|
@ -335,8 +305,8 @@ class SamlHandler:
|
|||
emails = attribute_dict.get("emails", [])
|
||||
|
||||
# Check if this mxid already exists
|
||||
if not await self._datastore.get_users_by_id_case_insensitive(
|
||||
UserID(localpart, self._hostname).to_string()
|
||||
if not await self.store.get_users_by_id_case_insensitive(
|
||||
UserID(localpart, self.server_name).to_string()
|
||||
):
|
||||
# This mxid is free
|
||||
break
|
||||
|
@ -348,7 +318,6 @@ class SamlHandler:
|
|||
)
|
||||
|
||||
logger.info("Mapped SAML user to local part %s", localpart)
|
||||
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart,
|
||||
default_display_name=displayname,
|
||||
|
@ -356,13 +325,13 @@ class SamlHandler:
|
|||
user_agent_ips=(user_agent, ip_address),
|
||||
)
|
||||
|
||||
await self._datastore.record_user_external_id(
|
||||
await self.store.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id
|
||||
)
|
||||
return registered_user_id
|
||||
|
||||
def expire_sessions(self):
|
||||
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
|
||||
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
||||
to_expire = set()
|
||||
for reqid, data in self._outstanding_requests_dict.items():
|
||||
if data.creation_time < expire_before:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue