mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-06-22 09:04:05 -04:00
Save the OIDC session ID (sid) with the device on login (#11482)
As a step towards allowing back-channel logout for OIDC.
This commit is contained in:
parent
8b4b153c9e
commit
a15a893df8
15 changed files with 370 additions and 65 deletions
|
@ -23,7 +23,7 @@ from authlib.common.security import generate_token
|
|||
from authlib.jose import JsonWebToken, jwt
|
||||
from authlib.oauth2.auth import ClientAuth
|
||||
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
|
||||
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
|
||||
from authlib.oidc.core import CodeIDToken, UserInfo
|
||||
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
|
||||
from jinja2 import Environment, Template
|
||||
from pymacaroons.exceptions import (
|
||||
|
@ -117,7 +117,8 @@ class OidcHandler:
|
|||
for idp_id, p in self._providers.items():
|
||||
try:
|
||||
await p.load_metadata()
|
||||
await p.load_jwks()
|
||||
if not p._uses_userinfo:
|
||||
await p.load_jwks()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
"Error while initialising OIDC provider %r" % (idp_id,)
|
||||
|
@ -498,10 +499,6 @@ class OidcProvider:
|
|||
return await self._jwks.get()
|
||||
|
||||
async def _load_jwks(self) -> JWKS:
|
||||
if self._uses_userinfo:
|
||||
# We're not using jwt signing, return an empty jwk set
|
||||
return {"keys": []}
|
||||
|
||||
metadata = await self.load_metadata()
|
||||
|
||||
# Load the JWKS using the `jwks_uri` metadata.
|
||||
|
@ -663,7 +660,7 @@ class OidcProvider:
|
|||
|
||||
return UserInfo(resp)
|
||||
|
||||
async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
|
||||
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
|
||||
"""Return an instance of UserInfo from token's ``id_token``.
|
||||
|
||||
Args:
|
||||
|
@ -673,7 +670,7 @@ class OidcProvider:
|
|||
request. This value should match the one inside the token.
|
||||
|
||||
Returns:
|
||||
An object representing the user.
|
||||
The decoded claims in the ID token.
|
||||
"""
|
||||
metadata = await self.load_metadata()
|
||||
claims_params = {
|
||||
|
@ -684,9 +681,6 @@ class OidcProvider:
|
|||
# If we got an `access_token`, there should be an `at_hash` claim
|
||||
# in the `id_token` that we can check against.
|
||||
claims_params["access_token"] = token["access_token"]
|
||||
claims_cls = CodeIDToken
|
||||
else:
|
||||
claims_cls = ImplicitIDToken
|
||||
|
||||
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
||||
jwt = JsonWebToken(alg_values)
|
||||
|
@ -703,7 +697,7 @@ class OidcProvider:
|
|||
claims = jwt.decode(
|
||||
id_token,
|
||||
key=jwk_set,
|
||||
claims_cls=claims_cls,
|
||||
claims_cls=CodeIDToken,
|
||||
claims_options=claim_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
|
@ -713,7 +707,7 @@ class OidcProvider:
|
|||
claims = jwt.decode(
|
||||
id_token,
|
||||
key=jwk_set,
|
||||
claims_cls=claims_cls,
|
||||
claims_cls=CodeIDToken,
|
||||
claims_options=claim_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
|
@ -721,7 +715,8 @@ class OidcProvider:
|
|||
logger.debug("Decoded id_token JWT %r; validating", claims)
|
||||
|
||||
claims.validate(leeway=120) # allows 2 min of clock skew
|
||||
return UserInfo(claims)
|
||||
|
||||
return claims
|
||||
|
||||
async def handle_redirect_request(
|
||||
self,
|
||||
|
@ -837,8 +832,22 @@ class OidcProvider:
|
|||
|
||||
logger.debug("Successfully obtained OAuth2 token data: %r", token)
|
||||
|
||||
# Now that we have a token, get the userinfo, either by decoding the
|
||||
# `id_token` or by fetching the `userinfo_endpoint`.
|
||||
# If there is an id_token, it should be validated, regardless of the
|
||||
# userinfo endpoint is used or not.
|
||||
if token.get("id_token") is not None:
|
||||
try:
|
||||
id_token = await self._parse_id_token(token, nonce=session_data.nonce)
|
||||
sid = id_token.get("sid")
|
||||
except Exception as e:
|
||||
logger.exception("Invalid id_token")
|
||||
self._sso_handler.render_error(request, "invalid_token", str(e))
|
||||
return
|
||||
else:
|
||||
id_token = None
|
||||
sid = None
|
||||
|
||||
# Now that we have a token, get the userinfo either from the `id_token`
|
||||
# claims or by fetching the `userinfo_endpoint`.
|
||||
if self._uses_userinfo:
|
||||
try:
|
||||
userinfo = await self._fetch_userinfo(token)
|
||||
|
@ -846,13 +855,14 @@ class OidcProvider:
|
|||
logger.exception("Could not fetch userinfo")
|
||||
self._sso_handler.render_error(request, "fetch_error", str(e))
|
||||
return
|
||||
elif id_token is not None:
|
||||
userinfo = UserInfo(id_token)
|
||||
else:
|
||||
try:
|
||||
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
|
||||
except Exception as e:
|
||||
logger.exception("Invalid id_token")
|
||||
self._sso_handler.render_error(request, "invalid_token", str(e))
|
||||
return
|
||||
logger.error("Missing id_token in token response")
|
||||
self._sso_handler.render_error(
|
||||
request, "invalid_token", "Missing id_token in token response"
|
||||
)
|
||||
return
|
||||
|
||||
# first check if we're doing a UIA
|
||||
if session_data.ui_auth_session_id:
|
||||
|
@ -884,7 +894,7 @@ class OidcProvider:
|
|||
# Call the mapper to register/login the user
|
||||
try:
|
||||
await self._complete_oidc_login(
|
||||
userinfo, token, request, session_data.client_redirect_url
|
||||
userinfo, token, request, session_data.client_redirect_url, sid
|
||||
)
|
||||
except MappingException as e:
|
||||
logger.exception("Could not map user")
|
||||
|
@ -896,6 +906,7 @@ class OidcProvider:
|
|||
token: Token,
|
||||
request: SynapseRequest,
|
||||
client_redirect_url: str,
|
||||
sid: Optional[str],
|
||||
) -> None:
|
||||
"""Given a UserInfo response, complete the login flow
|
||||
|
||||
|
@ -1008,6 +1019,7 @@ class OidcProvider:
|
|||
oidc_response_to_user_attributes,
|
||||
grandfather_existing_users,
|
||||
extra_attributes,
|
||||
auth_provider_session_id=sid,
|
||||
)
|
||||
|
||||
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue