Support OIDC backchannel logouts (#11414)

If configured an OIDC IdP can log a user's session out of
Synapse when they log out of the identity provider.

The IdP sends a request directly to Synapse (and must be
configured with an endpoint) when a user logs out.
This commit is contained in:
Quentin Gliech 2022-10-31 18:07:30 +01:00 committed by GitHub
parent 15bdb0da52
commit cc3a52b33d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 960 additions and 66 deletions

View file

@ -12,14 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import binascii
import inspect
import json
import logging
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
)
from urllib.parse import urlencode, urlparse
import attr
import unpaddedbase64
from authlib.common.security import generate_token
from authlib.jose import JsonWebToken, jwt
from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, UserInfo
@ -35,9 +49,12 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from twisted.web.http_headers import Headers
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@ -88,6 +105,8 @@ class Token(TypedDict):
#: there is no real point of doing this in our case.
JWK = Dict[str, str]
C = TypeVar("C")
#: A JWK Set, as per RFC7517 sec 5.
class JWKS(TypedDict):
@ -247,6 +266,80 @@ class OidcHandler:
await oidc_provider.handle_oidc_callback(request, session_data, code)
async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/client/oidc/backchannel_logout
This extracts the logout_token from the request and tries to figure out
which OpenID Provider it is comming from. This works by matching the iss claim
with the issuer and the aud claim with the client_id.
Since at this point we don't know who signed the JWT, we can't just
decode it using authlib since it will always verifies the signature. We
have to decode it manually without validating the signature. The actual JWT
verification is done in the `OidcProvider.handler_backchannel_logout` method,
once we figured out which provider sent the request.
Args:
request: the incoming request from the browser.
"""
logout_token = parse_string(request, "logout_token")
if logout_token is None:
raise SynapseError(400, "Missing logout_token in request")
# A JWT looks like this:
# header.payload.signature
# where all parts are encoded with urlsafe base64.
# The aud and iss claims we care about are in the payload part, which
# is a JSON object.
try:
# By destructuring the list after splitting, we ensure that we have
# exactly 3 segments
_, payload, _ = logout_token.split(".")
except ValueError:
raise SynapseError(400, "Invalid logout_token in request")
try:
payload_bytes = unpaddedbase64.decode_base64(payload)
claims = json_decoder.decode(payload_bytes.decode("utf-8"))
except (json.JSONDecodeError, binascii.Error, UnicodeError):
raise SynapseError(400, "Invalid logout_token payload in request")
try:
# Let's extract the iss and aud claims
iss = claims["iss"]
aud = claims["aud"]
# The aud claim can be either a string or a list of string. Here we
# normalize it as a list of strings.
if isinstance(aud, str):
aud = [aud]
# Check that we have the right types for the aud and the iss claims
if not isinstance(iss, str) or not isinstance(aud, list):
raise TypeError()
for a in aud:
if not isinstance(a, str):
raise TypeError()
# At this point we properly checked both claims types
issuer: str = iss
audience: List[str] = aud
except (TypeError, KeyError):
raise SynapseError(400, "Invalid issuer/audience in logout_token")
# Now that we know the audience and the issuer, we can figure out from
# what provider it is coming from
oidc_provider: Optional[OidcProvider] = None
for provider in self._providers.values():
if provider.issuer == issuer and provider.client_id in audience:
oidc_provider = provider
break
if oidc_provider is None:
raise SynapseError(400, "Could not find the OP that issued this event")
# Ask the provider to handle the logout request.
await oidc_provider.handle_backchannel_logout(request, logout_token)
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint"""
@ -342,6 +435,7 @@ class OidcProvider:
self.idp_brand = provider.idp_brand
self._sso_handler = hs.get_sso_handler()
self._device_handler = hs.get_device_handler()
self._sso_handler.register_identity_provider(self)
@ -400,6 +494,41 @@ class OidcProvider:
# If we're not using userinfo, we need a valid jwks to validate the ID token
m.validate_jwks_uri()
if self._config.backchannel_logout_enabled:
if not m.get("backchannel_logout_supported", False):
logger.warning(
"OIDC Back-Channel Logout is enabled for issuer %r"
"but it does not advertise support for it",
self.issuer,
)
elif not m.get("backchannel_logout_session_supported", False):
logger.warning(
"OIDC Back-Channel Logout is enabled and supported "
"by issuer %r but it might not send a session ID with "
"logout tokens, which is required for the logouts to work",
self.issuer,
)
if not self._config.backchannel_logout_ignore_sub:
# If OIDC backchannel logouts are enabled, the provider mapping provider
# should use the `sub` claim. We verify that by mapping a dumb user and
# see if we get back the sub claim
user = UserInfo({"sub": "thisisasubject"})
try:
subject = self._user_mapping_provider.get_remote_user_id(user)
if subject != user["sub"]:
raise ValueError("Unexpected subject")
except Exception:
logger.warning(
f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
"but it looks like the configured `user_mapping_provider` "
"does not use the `sub` claim as subject. If it is the case, "
"and you want Synapse to ignore the `sub` claim in OIDC "
"Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
"to `true` in the issuer config."
)
@property
def _uses_userinfo(self) -> bool:
"""Returns True if the ``userinfo_endpoint`` should be used.
@ -415,6 +544,16 @@ class OidcProvider:
or self._user_profile_method == "userinfo_endpoint"
)
@property
def issuer(self) -> str:
"""The issuer identifying this provider."""
return self._config.issuer
@property
def client_id(self) -> str:
"""The client_id used when interacting with this provider."""
return self._config.client_id
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
"""Return the provider metadata.
@ -662,6 +801,59 @@ class OidcProvider:
return UserInfo(resp)
async def _verify_jwt(
self,
alg_values: List[str],
token: str,
claims_cls: Type[C],
claims_options: Optional[dict] = None,
claims_params: Optional[dict] = None,
) -> C:
"""Decode and validate a JWT, re-fetching the JWKS as needed.
Args:
alg_values: list of `alg` values allowed when verifying the JWT.
token: the JWT.
claims_cls: the JWTClaims class to use to validate the claims.
claims_options: dict of options passed to the `claims_cls` constructor.
claims_params: dict of params passed to the `claims_cls` constructor.
Returns:
The decoded claims in the JWT.
"""
jwt = JsonWebToken(alg_values)
logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
# Try to decode the keys in cache first, then retry by forcing the keys
# to be reloaded
jwk_set = await self.load_jwks()
try:
claims = jwt.decode(
token,
key=jwk_set,
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
except ValueError:
logger.info("Reloading JWKS after decode error")
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
claims = jwt.decode(
token,
key=jwk_set,
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
claims.validate(
now=self._clock.time(), leeway=120
) # allows 2 min of clock skew
return claims
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
"""Return an instance of UserInfo from token's ``id_token``.
@ -675,13 +867,13 @@ class OidcProvider:
The decoded claims in the ID token.
"""
id_token = token.get("id_token")
logger.debug("Attempting to decode JWT id_token %r", id_token)
# That has been theoritically been checked by the caller, so even though
# assertion are not enabled in production, it is mainly here to appease mypy
assert id_token is not None
metadata = await self.load_metadata()
claims_params = {
"nonce": nonce,
"client_id": self._client_auth.client_id,
@ -691,38 +883,17 @@ class OidcProvider:
# in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"]
claims_options = {"iss": {"values": [metadata["issuer"]]}}
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values)
claim_options = {"iss": {"values": [metadata["issuer"]]}}
# Try to decode the keys in cache first, then retry by forcing the keys
# to be reloaded
jwk_set = await self.load_jwks()
try:
claims = jwt.decode(
id_token,
key=jwk_set,
claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
except ValueError:
logger.info("Reloading JWKS after decode error")
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
claims = jwt.decode(
id_token,
key=jwk_set,
claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
logger.debug("Decoded id_token JWT %r; validating", claims)
claims.validate(
now=self._clock.time(), leeway=120
) # allows 2 min of clock skew
claims = await self._verify_jwt(
alg_values=alg_values,
token=id_token,
claims_cls=CodeIDToken,
claims_options=claims_options,
claims_params=claims_params,
)
return claims
@ -1043,6 +1214,146 @@ class OidcProvider:
# to be strings.
return str(remote_user_id)
async def handle_backchannel_logout(
self, request: SynapseRequest, logout_token: str
) -> None:
"""Handle an incoming request to /_synapse/client/oidc/backchannel_logout
The OIDC Provider posts a logout token to this endpoint when a user
session ends. That token is a JWT signed with the same keys as
ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
validate the JWT and figure out what session to end.
Args:
request: The request to respond to
logout_token: The logout token (a JWT) extracted from the request body
"""
# Back-Channel Logout can be disabled in the config, hence this check.
# This is not that important for now since Synapse is registered
# manually to the OP, so not specifying the backchannel-logout URI is
# as effective than disabling it here. It might make more sense if we
# support dynamic registration in Synapse at some point.
if not self._config.backchannel_logout_enabled:
logger.warning(
f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
)
# TODO: this responds with a 400 status code, which is what the OIDC
# Back-Channel Logout spec expects, but spec also suggests answering with
# a JSON object, with the `error` and `error_description` fields set, which
# we are not doing here.
# See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
raise SynapseError(
400, "OpenID Connect Back-Channel Logout is disabled for this provider"
)
metadata = await self.load_metadata()
# As per OIDC Back-Channel Logout 1.0 sec. 2.4:
# A Logout Token MUST be signed and MAY also be encrypted. The same
# keys are used to sign and encrypt Logout Tokens as are used for ID
# Tokens. If the Logout Token is encrypted, it SHOULD replicate the
# iss (issuer) claim in the JWT Header Parameters, as specified in
# Section 5.3 of [JWT].
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
# As per sec. 2.6:
# 3. Validate the iss, aud, and iat Claims in the same way they are
# validated in ID Tokens.
# Which means the audience should contain Synapse's client_id and the
# issuer should be the IdP issuer
claims_options = {
"iss": {"values": [metadata["issuer"]]},
"aud": {"values": [self.client_id]},
}
try:
claims = await self._verify_jwt(
alg_values=alg_values,
token=logout_token,
claims_cls=LogoutToken,
claims_options=claims_options,
)
except JoseError:
logger.exception("Invalid logout_token")
raise SynapseError(400, "Invalid logout_token")
# As per sec. 2.6:
# 4. Verify that the Logout Token contains a sub Claim, a sid Claim,
# or both.
# 5. Verify that the Logout Token contains an events Claim whose
# value is JSON object containing the member name
# http://schemas.openid.net/event/backchannel-logout.
# 6. Verify that the Logout Token does not contain a nonce Claim.
# This is all verified by the LogoutToken claims class, so at this
# point the `sid` claim exists and is a string.
sid: str = claims.get("sid")
# If the `sub` claim was included in the logout token, we check that it matches
# that it matches the right user. We can have cases where the `sub` claim is not
# the ID saved in database, so we let admins disable this check in config.
sub: Optional[str] = claims.get("sub")
expected_user_id: Optional[str] = None
if sub is not None and not self._config.backchannel_logout_ignore_sub:
expected_user_id = await self._store.get_user_by_external_id(
self.idp_id, sub
)
# Invalidate any running user-mapping sessions, in-flight login tokens and
# active devices
await self._sso_handler.revoke_sessions_for_provider_session_id(
auth_provider_id=self.idp_id,
auth_provider_session_id=sid,
expected_user_id=expected_user_id,
)
request.setResponseCode(200)
request.setHeader(b"Cache-Control", b"no-cache, no-store")
request.setHeader(b"Pragma", b"no-cache")
finish_request(request)
class LogoutToken(JWTClaims):
"""
Holds and verify claims of a logout token, as per
https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
"""
REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
"""Validate everything in claims payload."""
super().validate(now, leeway)
self.validate_sid()
self.validate_events()
self.validate_nonce()
def validate_sid(self) -> None:
"""Ensure the sid claim is present"""
sid = self.get("sid")
if not sid:
raise MissingClaimError("sid")
if not isinstance(sid, str):
raise InvalidClaimError("sid")
def validate_nonce(self) -> None:
"""Ensure the nonce claim is absent"""
if "nonce" in self:
raise InvalidClaimError("nonce")
def validate_events(self) -> None:
"""Ensure the events claim is present and with the right value"""
events = self.get("events")
if not events:
raise MissingClaimError("events")
if not isinstance(events, dict):
raise InvalidClaimError("events")
if "http://schemas.openid.net/event/backchannel-logout" not in events:
raise InvalidClaimError("events")
# number of seconds a newly-generated client secret should be valid for
CLIENT_SECRET_VALIDITY_SECONDS = 3600
@ -1112,6 +1423,7 @@ class JwtClientSecret:
logger.info(
"Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
)
jwt = JsonWebToken(header["alg"])
self._cached_secret = jwt.encode(header, payload, self._key.key)
self._cached_secret_replacement_time = (
expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
@ -1126,9 +1438,6 @@ class UserAttributeDict(TypedDict):
emails: List[str]
C = TypeVar("C")
class OidcMappingProvider(Generic[C]):
"""A mapping provider maps a UserInfo object to user attributes.