mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-06-20 16:24:09 -04:00
Abstract shared SSO code. (#8765)
De-duplicates code between the SAML and OIDC implementations.
This commit is contained in:
parent
e487d9fabc
commit
ee382025b0
6 changed files with 159 additions and 120 deletions
1
changelog.d/8765.misc
Normal file
1
changelog.d/8765.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Consolidate logic between the OpenID Connect and SAML code.
|
|
@ -34,7 +34,8 @@ from typing_extensions import TypedDict
|
||||||
from twisted.web.client import readBody
|
from twisted.web.client import readBody
|
||||||
|
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.http.server import respond_with_html
|
from synapse.handlers._base import BaseHandler
|
||||||
|
from synapse.handlers.sso import MappingException
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||||
|
@ -83,17 +84,12 @@ class OidcError(Exception):
|
||||||
return self.error
|
return self.error
|
||||||
|
|
||||||
|
|
||||||
class MappingException(Exception):
|
class OidcHandler(BaseHandler):
|
||||||
"""Used to catch errors when mapping the UserInfo object
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class OidcHandler:
|
|
||||||
"""Handles requests related to the OpenID Connect login flow.
|
"""Handles requests related to the OpenID Connect login flow.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
super().__init__(hs)
|
||||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||||
self._scopes = hs.config.oidc_scopes # type: List[str]
|
self._scopes = hs.config.oidc_scopes # type: List[str]
|
||||||
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
|
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
|
||||||
|
@ -120,36 +116,13 @@ class OidcHandler:
|
||||||
self._http_client = hs.get_proxied_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self._registration_handler = hs.get_registration_handler()
|
self._registration_handler = hs.get_registration_handler()
|
||||||
self._datastore = hs.get_datastore()
|
|
||||||
self._clock = hs.get_clock()
|
|
||||||
self._hostname = hs.hostname # type: str
|
|
||||||
self._server_name = hs.config.server_name # type: str
|
self._server_name = hs.config.server_name # type: str
|
||||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||||
self._error_template = hs.config.sso_error_template
|
|
||||||
|
|
||||||
# identifier for the external_ids table
|
# identifier for the external_ids table
|
||||||
self._auth_provider_id = "oidc"
|
self._auth_provider_id = "oidc"
|
||||||
|
|
||||||
def _render_error(
|
self._sso_handler = hs.get_sso_handler()
|
||||||
self, request, error: str, error_description: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
"""Render the error template and respond to the request with it.
|
|
||||||
|
|
||||||
This is used to show errors to the user. The template of this page can
|
|
||||||
be found under `synapse/res/templates/sso_error.html`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The incoming request from the browser.
|
|
||||||
We'll respond with an HTML page describing the error.
|
|
||||||
error: A technical identifier for this error. Those include
|
|
||||||
well-known OAuth2/OIDC error types like invalid_request or
|
|
||||||
access_denied.
|
|
||||||
error_description: A human-readable description of the error.
|
|
||||||
"""
|
|
||||||
html = self._error_template.render(
|
|
||||||
error=error, error_description=error_description
|
|
||||||
)
|
|
||||||
respond_with_html(request, 400, html)
|
|
||||||
|
|
||||||
def _validate_metadata(self):
|
def _validate_metadata(self):
|
||||||
"""Verifies the provider metadata.
|
"""Verifies the provider metadata.
|
||||||
|
@ -571,7 +544,7 @@ class OidcHandler:
|
||||||
|
|
||||||
Since we might want to display OIDC-related errors in a user-friendly
|
Since we might want to display OIDC-related errors in a user-friendly
|
||||||
way, we don't raise SynapseError from here. Instead, we call
|
way, we don't raise SynapseError from here. Instead, we call
|
||||||
``self._render_error`` which displays an HTML page for the error.
|
``self._sso_handler.render_error`` which displays an HTML page for the error.
|
||||||
|
|
||||||
Most of the OpenID Connect logic happens here:
|
Most of the OpenID Connect logic happens here:
|
||||||
|
|
||||||
|
@ -609,7 +582,7 @@ class OidcHandler:
|
||||||
if error != "access_denied":
|
if error != "access_denied":
|
||||||
logger.error("Error from the OIDC provider: %s %s", error, description)
|
logger.error("Error from the OIDC provider: %s %s", error, description)
|
||||||
|
|
||||||
self._render_error(request, error, description)
|
self._sso_handler.render_error(request, error, description)
|
||||||
return
|
return
|
||||||
|
|
||||||
# otherwise, it is presumably a successful response. see:
|
# otherwise, it is presumably a successful response. see:
|
||||||
|
@ -619,7 +592,9 @@ class OidcHandler:
|
||||||
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
|
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
|
||||||
if session is None:
|
if session is None:
|
||||||
logger.info("No session cookie found")
|
logger.info("No session cookie found")
|
||||||
self._render_error(request, "missing_session", "No session cookie found")
|
self._sso_handler.render_error(
|
||||||
|
request, "missing_session", "No session cookie found"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Remove the cookie. There is a good chance that if the callback failed
|
# Remove the cookie. There is a good chance that if the callback failed
|
||||||
|
@ -637,7 +612,9 @@ class OidcHandler:
|
||||||
# Check for the state query parameter
|
# Check for the state query parameter
|
||||||
if b"state" not in request.args:
|
if b"state" not in request.args:
|
||||||
logger.info("State parameter is missing")
|
logger.info("State parameter is missing")
|
||||||
self._render_error(request, "invalid_request", "State parameter is missing")
|
self._sso_handler.render_error(
|
||||||
|
request, "invalid_request", "State parameter is missing"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
state = request.args[b"state"][0].decode()
|
state = request.args[b"state"][0].decode()
|
||||||
|
@ -651,17 +628,19 @@ class OidcHandler:
|
||||||
) = self._verify_oidc_session_token(session, state)
|
) = self._verify_oidc_session_token(session, state)
|
||||||
except MacaroonDeserializationException as e:
|
except MacaroonDeserializationException as e:
|
||||||
logger.exception("Invalid session")
|
logger.exception("Invalid session")
|
||||||
self._render_error(request, "invalid_session", str(e))
|
self._sso_handler.render_error(request, "invalid_session", str(e))
|
||||||
return
|
return
|
||||||
except MacaroonInvalidSignatureException as e:
|
except MacaroonInvalidSignatureException as e:
|
||||||
logger.exception("Could not verify session")
|
logger.exception("Could not verify session")
|
||||||
self._render_error(request, "mismatching_session", str(e))
|
self._sso_handler.render_error(request, "mismatching_session", str(e))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Exchange the code with the provider
|
# Exchange the code with the provider
|
||||||
if b"code" not in request.args:
|
if b"code" not in request.args:
|
||||||
logger.info("Code parameter is missing")
|
logger.info("Code parameter is missing")
|
||||||
self._render_error(request, "invalid_request", "Code parameter is missing")
|
self._sso_handler.render_error(
|
||||||
|
request, "invalid_request", "Code parameter is missing"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.debug("Exchanging code")
|
logger.debug("Exchanging code")
|
||||||
|
@ -670,7 +649,7 @@ class OidcHandler:
|
||||||
token = await self._exchange_code(code)
|
token = await self._exchange_code(code)
|
||||||
except OidcError as e:
|
except OidcError as e:
|
||||||
logger.exception("Could not exchange code")
|
logger.exception("Could not exchange code")
|
||||||
self._render_error(request, e.error, e.error_description)
|
self._sso_handler.render_error(request, e.error, e.error_description)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.debug("Successfully obtained OAuth2 access token")
|
logger.debug("Successfully obtained OAuth2 access token")
|
||||||
|
@ -683,7 +662,7 @@ class OidcHandler:
|
||||||
userinfo = await self._fetch_userinfo(token)
|
userinfo = await self._fetch_userinfo(token)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Could not fetch userinfo")
|
logger.exception("Could not fetch userinfo")
|
||||||
self._render_error(request, "fetch_error", str(e))
|
self._sso_handler.render_error(request, "fetch_error", str(e))
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
logger.debug("Extracting userinfo from id_token")
|
logger.debug("Extracting userinfo from id_token")
|
||||||
|
@ -691,7 +670,7 @@ class OidcHandler:
|
||||||
userinfo = await self._parse_id_token(token, nonce=nonce)
|
userinfo = await self._parse_id_token(token, nonce=nonce)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Invalid id_token")
|
logger.exception("Invalid id_token")
|
||||||
self._render_error(request, "invalid_token", str(e))
|
self._sso_handler.render_error(request, "invalid_token", str(e))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Pull out the user-agent and IP from the request.
|
# Pull out the user-agent and IP from the request.
|
||||||
|
@ -705,7 +684,7 @@ class OidcHandler:
|
||||||
)
|
)
|
||||||
except MappingException as e:
|
except MappingException as e:
|
||||||
logger.exception("Could not map user")
|
logger.exception("Could not map user")
|
||||||
self._render_error(request, "mapping_error", str(e))
|
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Mapping providers might not have get_extra_attributes: only call this
|
# Mapping providers might not have get_extra_attributes: only call this
|
||||||
|
@ -770,7 +749,7 @@ class OidcHandler:
|
||||||
macaroon.add_first_party_caveat(
|
macaroon.add_first_party_caveat(
|
||||||
"ui_auth_session_id = %s" % (ui_auth_session_id,)
|
"ui_auth_session_id = %s" % (ui_auth_session_id,)
|
||||||
)
|
)
|
||||||
now = self._clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
expiry = now + duration_in_ms
|
expiry = now + duration_in_ms
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
|
|
||||||
|
@ -845,7 +824,7 @@ class OidcHandler:
|
||||||
if not caveat.startswith(prefix):
|
if not caveat.startswith(prefix):
|
||||||
return False
|
return False
|
||||||
expiry = int(caveat[len(prefix) :])
|
expiry = int(caveat[len(prefix) :])
|
||||||
now = self._clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
return now < expiry
|
return now < expiry
|
||||||
|
|
||||||
async def _map_userinfo_to_user(
|
async def _map_userinfo_to_user(
|
||||||
|
@ -885,20 +864,14 @@ class OidcHandler:
|
||||||
# to be strings.
|
# to be strings.
|
||||||
remote_user_id = str(remote_user_id)
|
remote_user_id = str(remote_user_id)
|
||||||
|
|
||||||
logger.info(
|
# first of all, check if we already have a mapping for this user
|
||||||
"Looking for existing mapping for user %s:%s",
|
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
|
||||||
self._auth_provider_id,
|
|
||||||
remote_user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
registered_user_id = await self._datastore.get_user_by_external_id(
|
|
||||||
self._auth_provider_id, remote_user_id,
|
self._auth_provider_id, remote_user_id,
|
||||||
)
|
)
|
||||||
|
if previously_registered_user_id:
|
||||||
|
return previously_registered_user_id
|
||||||
|
|
||||||
if registered_user_id is not None:
|
# Otherwise, generate a new user.
|
||||||
logger.info("Found existing mapping %s", registered_user_id)
|
|
||||||
return registered_user_id
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
attributes = await self._user_mapping_provider.map_user_attributes(
|
attributes = await self._user_mapping_provider.map_user_attributes(
|
||||||
userinfo, token
|
userinfo, token
|
||||||
|
@ -917,8 +890,8 @@ class OidcHandler:
|
||||||
|
|
||||||
localpart = map_username_to_mxid_localpart(attributes["localpart"])
|
localpart = map_username_to_mxid_localpart(attributes["localpart"])
|
||||||
|
|
||||||
user_id = UserID(localpart, self._hostname).to_string()
|
user_id = UserID(localpart, self.server_name).to_string()
|
||||||
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
|
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
if users:
|
if users:
|
||||||
if self._allow_existing_users:
|
if self._allow_existing_users:
|
||||||
if len(users) == 1:
|
if len(users) == 1:
|
||||||
|
@ -942,7 +915,8 @@ class OidcHandler:
|
||||||
default_display_name=attributes["display_name"],
|
default_display_name=attributes["display_name"],
|
||||||
user_agent_ips=(user_agent, ip_address),
|
user_agent_ips=(user_agent, ip_address),
|
||||||
)
|
)
|
||||||
await self._datastore.record_user_external_id(
|
|
||||||
|
await self.store.record_user_external_id(
|
||||||
self._auth_provider_id, remote_user_id, registered_user_id,
|
self._auth_provider_id, remote_user_id, registered_user_id,
|
||||||
)
|
)
|
||||||
return registered_user_id
|
return registered_user_id
|
||||||
|
|
|
@ -24,7 +24,8 @@ from saml2.client import Saml2Client
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.config.saml2_config import SamlAttributeRequirement
|
from synapse.config.saml2_config import SamlAttributeRequirement
|
||||||
from synapse.http.server import respond_with_html
|
from synapse.handlers._base import BaseHandler
|
||||||
|
from synapse.handlers.sso import MappingException
|
||||||
from synapse.http.servlet import parse_string
|
from synapse.http.servlet import parse_string
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
|
@ -42,10 +43,6 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MappingException(Exception):
|
|
||||||
"""Used to catch errors when mapping the SAML2 response to a user."""
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
class Saml2SessionData:
|
class Saml2SessionData:
|
||||||
"""Data we track about SAML2 sessions"""
|
"""Data we track about SAML2 sessions"""
|
||||||
|
@ -57,17 +54,13 @@ class Saml2SessionData:
|
||||||
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
|
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
|
||||||
|
|
||||||
|
|
||||||
class SamlHandler:
|
class SamlHandler(BaseHandler):
|
||||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||||
self.hs = hs
|
super().__init__(hs)
|
||||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||||
self._auth = hs.get_auth()
|
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self._registration_handler = hs.get_registration_handler()
|
self._registration_handler = hs.get_registration_handler()
|
||||||
|
|
||||||
self._clock = hs.get_clock()
|
|
||||||
self._datastore = hs.get_datastore()
|
|
||||||
self._hostname = hs.hostname
|
|
||||||
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
|
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
|
||||||
self._grandfathered_mxid_source_attribute = (
|
self._grandfathered_mxid_source_attribute = (
|
||||||
hs.config.saml2_grandfathered_mxid_source_attribute
|
hs.config.saml2_grandfathered_mxid_source_attribute
|
||||||
|
@ -88,26 +81,9 @@ class SamlHandler:
|
||||||
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
||||||
|
|
||||||
# a lock on the mappings
|
# a lock on the mappings
|
||||||
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
|
self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
|
||||||
|
|
||||||
def _render_error(
|
self._sso_handler = hs.get_sso_handler()
|
||||||
self, request, error: str, error_description: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
"""Render the error template and respond to the request with it.
|
|
||||||
|
|
||||||
This is used to show errors to the user. The template of this page can
|
|
||||||
be found under `synapse/res/templates/sso_error.html`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The incoming request from the browser.
|
|
||||||
We'll respond with an HTML page describing the error.
|
|
||||||
error: A technical identifier for this error.
|
|
||||||
error_description: A human-readable description of the error.
|
|
||||||
"""
|
|
||||||
html = self._error_template.render(
|
|
||||||
error=error, error_description=error_description
|
|
||||||
)
|
|
||||||
respond_with_html(request, 400, html)
|
|
||||||
|
|
||||||
def handle_redirect_request(
|
def handle_redirect_request(
|
||||||
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
|
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
|
||||||
|
@ -130,7 +106,7 @@ class SamlHandler:
|
||||||
# Since SAML sessions timeout it is useful to log when they were created.
|
# Since SAML sessions timeout it is useful to log when they were created.
|
||||||
logger.info("Initiating a new SAML session: %s" % (reqid,))
|
logger.info("Initiating a new SAML session: %s" % (reqid,))
|
||||||
|
|
||||||
now = self._clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
self._outstanding_requests_dict[reqid] = Saml2SessionData(
|
self._outstanding_requests_dict[reqid] = Saml2SessionData(
|
||||||
creation_time=now, ui_auth_session_id=ui_auth_session_id,
|
creation_time=now, ui_auth_session_id=ui_auth_session_id,
|
||||||
)
|
)
|
||||||
|
@ -171,12 +147,12 @@ class SamlHandler:
|
||||||
# in the (user-visible) exception message, so let's log the exception here
|
# in the (user-visible) exception message, so let's log the exception here
|
||||||
# so we can track down the session IDs later.
|
# so we can track down the session IDs later.
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
self._render_error(
|
self._sso_handler.render_error(
|
||||||
request, "unsolicited_response", "Unexpected SAML2 login."
|
request, "unsolicited_response", "Unexpected SAML2 login."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._render_error(
|
self._sso_handler.render_error(
|
||||||
request,
|
request,
|
||||||
"invalid_response",
|
"invalid_response",
|
||||||
"Unable to parse SAML2 response: %s." % (e,),
|
"Unable to parse SAML2 response: %s." % (e,),
|
||||||
|
@ -184,7 +160,7 @@ class SamlHandler:
|
||||||
return
|
return
|
||||||
|
|
||||||
if saml2_auth.not_signed:
|
if saml2_auth.not_signed:
|
||||||
self._render_error(
|
self._sso_handler.render_error(
|
||||||
request, "unsigned_respond", "SAML2 response was not signed."
|
request, "unsigned_respond", "SAML2 response was not signed."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
@ -210,7 +186,7 @@ class SamlHandler:
|
||||||
# attributes.
|
# attributes.
|
||||||
for requirement in self._saml2_attribute_requirements:
|
for requirement in self._saml2_attribute_requirements:
|
||||||
if not _check_attribute_requirement(saml2_auth.ava, requirement):
|
if not _check_attribute_requirement(saml2_auth.ava, requirement):
|
||||||
self._render_error(
|
self._sso_handler.render_error(
|
||||||
request, "unauthorised", "You are not authorised to log in here."
|
request, "unauthorised", "You are not authorised to log in here."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
@ -226,7 +202,7 @@ class SamlHandler:
|
||||||
)
|
)
|
||||||
except MappingException as e:
|
except MappingException as e:
|
||||||
logger.exception("Could not map user")
|
logger.exception("Could not map user")
|
||||||
self._render_error(request, "mapping_error", str(e))
|
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Complete the interactive auth session or the login.
|
# Complete the interactive auth session or the login.
|
||||||
|
@ -274,17 +250,11 @@ class SamlHandler:
|
||||||
|
|
||||||
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
||||||
# first of all, check if we already have a mapping for this user
|
# first of all, check if we already have a mapping for this user
|
||||||
logger.info(
|
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
|
||||||
"Looking for existing mapping for user %s:%s",
|
self._auth_provider_id, remote_user_id,
|
||||||
self._auth_provider_id,
|
|
||||||
remote_user_id,
|
|
||||||
)
|
)
|
||||||
registered_user_id = await self._datastore.get_user_by_external_id(
|
if previously_registered_user_id:
|
||||||
self._auth_provider_id, remote_user_id
|
return previously_registered_user_id
|
||||||
)
|
|
||||||
if registered_user_id is not None:
|
|
||||||
logger.info("Found existing mapping %s", registered_user_id)
|
|
||||||
return registered_user_id
|
|
||||||
|
|
||||||
# backwards-compatibility hack: see if there is an existing user with a
|
# backwards-compatibility hack: see if there is an existing user with a
|
||||||
# suitable mapping from the uid
|
# suitable mapping from the uid
|
||||||
|
@ -294,7 +264,7 @@ class SamlHandler:
|
||||||
):
|
):
|
||||||
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
|
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
|
||||||
user_id = UserID(
|
user_id = UserID(
|
||||||
map_username_to_mxid_localpart(attrval), self._hostname
|
map_username_to_mxid_localpart(attrval), self.server_name
|
||||||
).to_string()
|
).to_string()
|
||||||
logger.info(
|
logger.info(
|
||||||
"Looking for existing account based on mapped %s %s",
|
"Looking for existing account based on mapped %s %s",
|
||||||
|
@ -302,11 +272,11 @@ class SamlHandler:
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
|
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
if users:
|
if users:
|
||||||
registered_user_id = list(users.keys())[0]
|
registered_user_id = list(users.keys())[0]
|
||||||
logger.info("Grandfathering mapping to %s", registered_user_id)
|
logger.info("Grandfathering mapping to %s", registered_user_id)
|
||||||
await self._datastore.record_user_external_id(
|
await self.store.record_user_external_id(
|
||||||
self._auth_provider_id, remote_user_id, registered_user_id
|
self._auth_provider_id, remote_user_id, registered_user_id
|
||||||
)
|
)
|
||||||
return registered_user_id
|
return registered_user_id
|
||||||
|
@ -335,8 +305,8 @@ class SamlHandler:
|
||||||
emails = attribute_dict.get("emails", [])
|
emails = attribute_dict.get("emails", [])
|
||||||
|
|
||||||
# Check if this mxid already exists
|
# Check if this mxid already exists
|
||||||
if not await self._datastore.get_users_by_id_case_insensitive(
|
if not await self.store.get_users_by_id_case_insensitive(
|
||||||
UserID(localpart, self._hostname).to_string()
|
UserID(localpart, self.server_name).to_string()
|
||||||
):
|
):
|
||||||
# This mxid is free
|
# This mxid is free
|
||||||
break
|
break
|
||||||
|
@ -348,7 +318,6 @@ class SamlHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Mapped SAML user to local part %s", localpart)
|
logger.info("Mapped SAML user to local part %s", localpart)
|
||||||
|
|
||||||
registered_user_id = await self._registration_handler.register_user(
|
registered_user_id = await self._registration_handler.register_user(
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
default_display_name=displayname,
|
default_display_name=displayname,
|
||||||
|
@ -356,13 +325,13 @@ class SamlHandler:
|
||||||
user_agent_ips=(user_agent, ip_address),
|
user_agent_ips=(user_agent, ip_address),
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._datastore.record_user_external_id(
|
await self.store.record_user_external_id(
|
||||||
self._auth_provider_id, remote_user_id, registered_user_id
|
self._auth_provider_id, remote_user_id, registered_user_id
|
||||||
)
|
)
|
||||||
return registered_user_id
|
return registered_user_id
|
||||||
|
|
||||||
def expire_sessions(self):
|
def expire_sessions(self):
|
||||||
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
|
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
||||||
to_expire = set()
|
to_expire = set()
|
||||||
for reqid, data in self._outstanding_requests_dict.items():
|
for reqid, data in self._outstanding_requests_dict.items():
|
||||||
if data.creation_time < expire_before:
|
if data.creation_time < expire_before:
|
||||||
|
|
90
synapse/handlers/sso.py
Normal file
90
synapse/handlers/sso.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from synapse.handlers._base import BaseHandler
|
||||||
|
from synapse.http.server import respond_with_html
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MappingException(Exception):
|
||||||
|
"""Used to catch errors when mapping the UserInfo object
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class SsoHandler(BaseHandler):
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__(hs)
|
||||||
|
self._error_template = hs.config.sso_error_template
|
||||||
|
|
||||||
|
def render_error(
|
||||||
|
self, request, error: str, error_description: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Renders the error template and responds with it.
|
||||||
|
|
||||||
|
This is used to show errors to the user. The template of this page can
|
||||||
|
be found under `synapse/res/templates/sso_error.html`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The incoming request from the browser.
|
||||||
|
We'll respond with an HTML page describing the error.
|
||||||
|
error: A technical identifier for this error.
|
||||||
|
error_description: A human-readable description of the error.
|
||||||
|
"""
|
||||||
|
html = self._error_template.render(
|
||||||
|
error=error, error_description=error_description
|
||||||
|
)
|
||||||
|
respond_with_html(request, 400, html)
|
||||||
|
|
||||||
|
async def get_sso_user_by_remote_user_id(
|
||||||
|
self, auth_provider_id: str, remote_user_id: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Maps the user ID of a remote IdP to a mxid for a previously seen user.
|
||||||
|
|
||||||
|
If the user has not been seen yet, this will return None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||||
|
"oidc" or "saml".
|
||||||
|
remote_user_id: The user ID according to the remote IdP. This might
|
||||||
|
be an e-mail address, a GUID, or some other form. It must be
|
||||||
|
unique and immutable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The mxid of a previously seen user.
|
||||||
|
"""
|
||||||
|
# Check if we already have a mapping for this user.
|
||||||
|
logger.info(
|
||||||
|
"Looking for existing mapping for user %s:%s",
|
||||||
|
auth_provider_id,
|
||||||
|
remote_user_id,
|
||||||
|
)
|
||||||
|
previously_registered_user_id = await self.store.get_user_by_external_id(
|
||||||
|
auth_provider_id, remote_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# A match was found, return the user ID.
|
||||||
|
if previously_registered_user_id is not None:
|
||||||
|
logger.info("Found existing mapping %s", previously_registered_user_id)
|
||||||
|
return previously_registered_user_id
|
||||||
|
|
||||||
|
# No match.
|
||||||
|
return None
|
|
@ -89,6 +89,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler
|
||||||
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
|
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
|
||||||
from synapse.handlers.search import SearchHandler
|
from synapse.handlers.search import SearchHandler
|
||||||
from synapse.handlers.set_password import SetPasswordHandler
|
from synapse.handlers.set_password import SetPasswordHandler
|
||||||
|
from synapse.handlers.sso import SsoHandler
|
||||||
from synapse.handlers.stats import StatsHandler
|
from synapse.handlers.stats import StatsHandler
|
||||||
from synapse.handlers.sync import SyncHandler
|
from synapse.handlers.sync import SyncHandler
|
||||||
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
|
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
|
||||||
|
@ -390,6 +391,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
else:
|
else:
|
||||||
return FollowerTypingHandler(self)
|
return FollowerTypingHandler(self)
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
|
def get_sso_handler(self) -> SsoHandler:
|
||||||
|
return SsoHandler(self)
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_sync_handler(self) -> SyncHandler:
|
def get_sync_handler(self) -> SyncHandler:
|
||||||
return SyncHandler(self)
|
return SyncHandler(self)
|
||||||
|
|
|
@ -154,6 +154,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.handler = OidcHandler(hs)
|
self.handler = OidcHandler(hs)
|
||||||
|
# Mock the render error method.
|
||||||
|
self.render_error = Mock(return_value=None)
|
||||||
|
self.handler._sso_handler.render_error = self.render_error
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
@ -161,12 +164,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
return patch.dict(self.handler._provider_metadata, values)
|
return patch.dict(self.handler._provider_metadata, values)
|
||||||
|
|
||||||
def assertRenderedError(self, error, error_description=None):
|
def assertRenderedError(self, error, error_description=None):
|
||||||
args = self.handler._render_error.call_args[0]
|
args = self.render_error.call_args[0]
|
||||||
self.assertEqual(args[1], error)
|
self.assertEqual(args[1], error)
|
||||||
if error_description is not None:
|
if error_description is not None:
|
||||||
self.assertEqual(args[2], error_description)
|
self.assertEqual(args[2], error_description)
|
||||||
# Reset the render_error mock
|
# Reset the render_error mock
|
||||||
self.handler._render_error.reset_mock()
|
self.render_error.reset_mock()
|
||||||
|
|
||||||
def test_config(self):
|
def test_config(self):
|
||||||
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
||||||
|
@ -356,7 +359,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
def test_callback_error(self):
|
def test_callback_error(self):
|
||||||
"""Errors from the provider returned in the callback are displayed."""
|
"""Errors from the provider returned in the callback are displayed."""
|
||||||
self.handler._render_error = Mock()
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"error"] = [b"invalid_client"]
|
request.args[b"error"] = [b"invalid_client"]
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
@ -387,7 +389,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"preferred_username": "bar",
|
"preferred_username": "bar",
|
||||||
}
|
}
|
||||||
user_id = "@foo:domain.org"
|
user_id = "@foo:domain.org"
|
||||||
self.handler._render_error = Mock(return_value=None)
|
|
||||||
self.handler._exchange_code = simple_async_mock(return_value=token)
|
self.handler._exchange_code = simple_async_mock(return_value=token)
|
||||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||||
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||||
|
@ -435,7 +436,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
userinfo, token, user_agent, ip_address
|
userinfo, token, user_agent, ip_address
|
||||||
)
|
)
|
||||||
self.handler._fetch_userinfo.assert_not_called()
|
self.handler._fetch_userinfo.assert_not_called()
|
||||||
self.handler._render_error.assert_not_called()
|
self.render_error.assert_not_called()
|
||||||
|
|
||||||
# Handle mapping errors
|
# Handle mapping errors
|
||||||
self.handler._map_userinfo_to_user = simple_async_mock(
|
self.handler._map_userinfo_to_user = simple_async_mock(
|
||||||
|
@ -469,7 +470,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
userinfo, token, user_agent, ip_address
|
userinfo, token, user_agent, ip_address
|
||||||
)
|
)
|
||||||
self.handler._fetch_userinfo.assert_called_once_with(token)
|
self.handler._fetch_userinfo.assert_called_once_with(token)
|
||||||
self.handler._render_error.assert_not_called()
|
self.render_error.assert_not_called()
|
||||||
|
|
||||||
# Handle userinfo fetching error
|
# Handle userinfo fetching error
|
||||||
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
|
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
|
||||||
|
@ -485,7 +486,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
def test_callback_session(self):
|
def test_callback_session(self):
|
||||||
"""The callback verifies the session presence and validity"""
|
"""The callback verifies the session presence and validity"""
|
||||||
self.handler._render_error = Mock(return_value=None)
|
|
||||||
request = Mock(spec=["args", "getCookie", "addCookie"])
|
request = Mock(spec=["args", "getCookie", "addCookie"])
|
||||||
|
|
||||||
# Missing cookie
|
# Missing cookie
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue