Support OIDC backchannel logouts (#11414)

If configured an OIDC IdP can log a user's session out of
Synapse when they log out of the identity provider.

The IdP sends a request directly to Synapse (and must be
configured with an endpoint) when a user logs out.
This commit is contained in:
Quentin Gliech 2022-10-31 18:07:30 +01:00 committed by GitHub
parent 15bdb0da52
commit cc3a52b33d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 960 additions and 66 deletions

View File

@ -0,0 +1 @@
Support back-channel logouts from OpenID Connect providers.

View File

@ -49,6 +49,13 @@ setting in your configuration file.
See the [configuration manual](usage/configuration/config_documentation.md#oidc_providers) for some sample settings, as well as See the [configuration manual](usage/configuration/config_documentation.md#oidc_providers) for some sample settings, as well as
the text below for example configurations for specific providers. the text below for example configurations for specific providers.
## OIDC Back-Channel Logout
Synapse supports receiving [OpenID Connect Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) notifications.
This lets the OpenID Connect Provider notify Synapse when a user logs out, so that Synapse can end that user session.
This feature can be enabled by setting the `backchannel_logout_enabled` property to `true` in the provider configuration, and setting the following URL as destination for Back-Channel Logout notifications in your OpenID Connect Provider: `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout`
## Sample configs ## Sample configs
Here are a few configs for providers that should work with Synapse. Here are a few configs for providers that should work with Synapse.
@ -123,6 +130,9 @@ oidc_providers:
[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat. [Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
Keycloak supports OIDC Back-Channel Logout, which sends logout notification to Synapse, so that Synapse users get logged out when they log out from Keycloak.
This can be optionally enabled by setting `backchannel_logout_enabled` to `true` in the Synapse configuration, and by setting the "Backchannel Logout URL" in Keycloak.
Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm. Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm.
1. Click `Clients` in the sidebar and click `Create` 1. Click `Clients` in the sidebar and click `Create`
@ -144,6 +154,8 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
| Client Protocol | `openid-connect` | | Client Protocol | `openid-connect` |
| Access Type | `confidential` | | Access Type | `confidential` |
| Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` | | Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` |
| Backchannel Logout URL (optional) | `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` |
| Backchannel Logout Session Required (optional) | `On` |
5. Click `Save` 5. Click `Save`
6. On the Credentials tab, update the fields: 6. On the Credentials tab, update the fields:
@ -167,7 +179,9 @@ oidc_providers:
config: config:
localpart_template: "{{ user.preferred_username }}" localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}" display_name_template: "{{ user.name }}"
backchannel_logout_enabled: true # Optional
``` ```
### Auth0 ### Auth0
[Auth0][auth0] is a hosted SaaS IdP solution. [Auth0][auth0] is a hosted SaaS IdP solution.

View File

@ -3021,6 +3021,15 @@ Options for each entry include:
which is set to the claims returned by the UserInfo Endpoint and/or which is set to the claims returned by the UserInfo Endpoint and/or
in the ID Token. in the ID Token.
* `backchannel_logout_enabled`: set to `true` to process OIDC Back-Channel Logout notifications.
Those notifications are expected to be received on `/_synapse/client/oidc/backchannel_logout`.
Defaults to `false`.
* `backchannel_logout_ignore_sub`: by default, the OIDC Back-Channel Logout feature checks that the
`sub` claim matches the subject claim received during login. This check can be disabled by setting
this to `true`. Defaults to `false`.
You might want to disable this if the `subject_claim` returned by the mapping provider is not `sub`.
It is possible to configure Synapse to only allow logins if certain attributes It is possible to configure Synapse to only allow logins if certain attributes
match particular values in the OIDC userinfo. The requirements can be listed under match particular values in the OIDC userinfo. The requirements can be listed under

View File

@ -123,6 +123,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"userinfo_endpoint": {"type": "string"}, "userinfo_endpoint": {"type": "string"},
"jwks_uri": {"type": "string"}, "jwks_uri": {"type": "string"},
"skip_verification": {"type": "boolean"}, "skip_verification": {"type": "boolean"},
"backchannel_logout_enabled": {"type": "boolean"},
"backchannel_logout_ignore_sub": {"type": "boolean"},
"user_profile_method": { "user_profile_method": {
"type": "string", "type": "string",
"enum": ["auto", "userinfo_endpoint"], "enum": ["auto", "userinfo_endpoint"],
@ -292,6 +294,10 @@ def _parse_oidc_config_dict(
token_endpoint=oidc_config.get("token_endpoint"), token_endpoint=oidc_config.get("token_endpoint"),
userinfo_endpoint=oidc_config.get("userinfo_endpoint"), userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
jwks_uri=oidc_config.get("jwks_uri"), jwks_uri=oidc_config.get("jwks_uri"),
backchannel_logout_enabled=oidc_config.get("backchannel_logout_enabled", False),
backchannel_logout_ignore_sub=oidc_config.get(
"backchannel_logout_ignore_sub", False
),
skip_verification=oidc_config.get("skip_verification", False), skip_verification=oidc_config.get("skip_verification", False),
user_profile_method=oidc_config.get("user_profile_method", "auto"), user_profile_method=oidc_config.get("user_profile_method", "auto"),
allow_existing_users=oidc_config.get("allow_existing_users", False), allow_existing_users=oidc_config.get("allow_existing_users", False),
@ -368,6 +374,12 @@ class OidcProviderConfig:
# "openid" scope is used. # "openid" scope is used.
jwks_uri: Optional[str] jwks_uri: Optional[str]
# Whether Synapse should react to backchannel logouts
backchannel_logout_enabled: bool
# Whether Synapse should ignore the `sub` claim in backchannel logouts or not.
backchannel_logout_ignore_sub: bool
# Whether to skip metadata verification # Whether to skip metadata verification
skip_verification: bool skip_verification: bool

View File

@ -12,14 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import binascii
import inspect import inspect
import json
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
)
from urllib.parse import urlencode, urlparse from urllib.parse import urlencode, urlparse
import attr import attr
import unpaddedbase64
from authlib.common.security import generate_token from authlib.common.security import generate_token
from authlib.jose import JsonWebToken, jwt from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
from authlib.oauth2.auth import ClientAuth from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, UserInfo from authlib.oidc.core import CodeIDToken, UserInfo
@ -35,9 +49,12 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody from twisted.web.client import readBody
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from synapse.api.errors import SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
from synapse.handlers.sso import MappingException, UserAttributes from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@ -88,6 +105,8 @@ class Token(TypedDict):
#: there is no real point of doing this in our case. #: there is no real point of doing this in our case.
JWK = Dict[str, str] JWK = Dict[str, str]
C = TypeVar("C")
#: A JWK Set, as per RFC7517 sec 5. #: A JWK Set, as per RFC7517 sec 5.
class JWKS(TypedDict): class JWKS(TypedDict):
@ -247,6 +266,80 @@ class OidcHandler:
await oidc_provider.handle_oidc_callback(request, session_data, code) await oidc_provider.handle_oidc_callback(request, session_data, code)
async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/client/oidc/backchannel_logout
This extracts the logout_token from the request and tries to figure out
which OpenID Provider it is comming from. This works by matching the iss claim
with the issuer and the aud claim with the client_id.
Since at this point we don't know who signed the JWT, we can't just
decode it using authlib since it will always verifies the signature. We
have to decode it manually without validating the signature. The actual JWT
verification is done in the `OidcProvider.handler_backchannel_logout` method,
once we figured out which provider sent the request.
Args:
request: the incoming request from the browser.
"""
logout_token = parse_string(request, "logout_token")
if logout_token is None:
raise SynapseError(400, "Missing logout_token in request")
# A JWT looks like this:
# header.payload.signature
# where all parts are encoded with urlsafe base64.
# The aud and iss claims we care about are in the payload part, which
# is a JSON object.
try:
# By destructuring the list after splitting, we ensure that we have
# exactly 3 segments
_, payload, _ = logout_token.split(".")
except ValueError:
raise SynapseError(400, "Invalid logout_token in request")
try:
payload_bytes = unpaddedbase64.decode_base64(payload)
claims = json_decoder.decode(payload_bytes.decode("utf-8"))
except (json.JSONDecodeError, binascii.Error, UnicodeError):
raise SynapseError(400, "Invalid logout_token payload in request")
try:
# Let's extract the iss and aud claims
iss = claims["iss"]
aud = claims["aud"]
# The aud claim can be either a string or a list of string. Here we
# normalize it as a list of strings.
if isinstance(aud, str):
aud = [aud]
# Check that we have the right types for the aud and the iss claims
if not isinstance(iss, str) or not isinstance(aud, list):
raise TypeError()
for a in aud:
if not isinstance(a, str):
raise TypeError()
# At this point we properly checked both claims types
issuer: str = iss
audience: List[str] = aud
except (TypeError, KeyError):
raise SynapseError(400, "Invalid issuer/audience in logout_token")
# Now that we know the audience and the issuer, we can figure out from
# what provider it is coming from
oidc_provider: Optional[OidcProvider] = None
for provider in self._providers.values():
if provider.issuer == issuer and provider.client_id in audience:
oidc_provider = provider
break
if oidc_provider is None:
raise SynapseError(400, "Could not find the OP that issued this event")
# Ask the provider to handle the logout request.
await oidc_provider.handle_backchannel_logout(request, logout_token)
class OidcError(Exception): class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint""" """Used to catch errors when calling the token_endpoint"""
@ -342,6 +435,7 @@ class OidcProvider:
self.idp_brand = provider.idp_brand self.idp_brand = provider.idp_brand
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._device_handler = hs.get_device_handler()
self._sso_handler.register_identity_provider(self) self._sso_handler.register_identity_provider(self)
@ -400,6 +494,41 @@ class OidcProvider:
# If we're not using userinfo, we need a valid jwks to validate the ID token # If we're not using userinfo, we need a valid jwks to validate the ID token
m.validate_jwks_uri() m.validate_jwks_uri()
if self._config.backchannel_logout_enabled:
if not m.get("backchannel_logout_supported", False):
logger.warning(
"OIDC Back-Channel Logout is enabled for issuer %r"
"but it does not advertise support for it",
self.issuer,
)
elif not m.get("backchannel_logout_session_supported", False):
logger.warning(
"OIDC Back-Channel Logout is enabled and supported "
"by issuer %r but it might not send a session ID with "
"logout tokens, which is required for the logouts to work",
self.issuer,
)
if not self._config.backchannel_logout_ignore_sub:
# If OIDC backchannel logouts are enabled, the provider mapping provider
# should use the `sub` claim. We verify that by mapping a dumb user and
# see if we get back the sub claim
user = UserInfo({"sub": "thisisasubject"})
try:
subject = self._user_mapping_provider.get_remote_user_id(user)
if subject != user["sub"]:
raise ValueError("Unexpected subject")
except Exception:
logger.warning(
f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
"but it looks like the configured `user_mapping_provider` "
"does not use the `sub` claim as subject. If it is the case, "
"and you want Synapse to ignore the `sub` claim in OIDC "
"Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
"to `true` in the issuer config."
)
@property @property
def _uses_userinfo(self) -> bool: def _uses_userinfo(self) -> bool:
"""Returns True if the ``userinfo_endpoint`` should be used. """Returns True if the ``userinfo_endpoint`` should be used.
@ -415,6 +544,16 @@ class OidcProvider:
or self._user_profile_method == "userinfo_endpoint" or self._user_profile_method == "userinfo_endpoint"
) )
@property
def issuer(self) -> str:
"""The issuer identifying this provider."""
return self._config.issuer
@property
def client_id(self) -> str:
"""The client_id used when interacting with this provider."""
return self._config.client_id
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata: async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
"""Return the provider metadata. """Return the provider metadata.
@ -662,6 +801,59 @@ class OidcProvider:
return UserInfo(resp) return UserInfo(resp)
async def _verify_jwt(
self,
alg_values: List[str],
token: str,
claims_cls: Type[C],
claims_options: Optional[dict] = None,
claims_params: Optional[dict] = None,
) -> C:
"""Decode and validate a JWT, re-fetching the JWKS as needed.
Args:
alg_values: list of `alg` values allowed when verifying the JWT.
token: the JWT.
claims_cls: the JWTClaims class to use to validate the claims.
claims_options: dict of options passed to the `claims_cls` constructor.
claims_params: dict of params passed to the `claims_cls` constructor.
Returns:
The decoded claims in the JWT.
"""
jwt = JsonWebToken(alg_values)
logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
# Try to decode the keys in cache first, then retry by forcing the keys
# to be reloaded
jwk_set = await self.load_jwks()
try:
claims = jwt.decode(
token,
key=jwk_set,
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
except ValueError:
logger.info("Reloading JWKS after decode error")
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
claims = jwt.decode(
token,
key=jwk_set,
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
claims.validate(
now=self._clock.time(), leeway=120
) # allows 2 min of clock skew
return claims
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
"""Return an instance of UserInfo from token's ``id_token``. """Return an instance of UserInfo from token's ``id_token``.
@ -675,13 +867,13 @@ class OidcProvider:
The decoded claims in the ID token. The decoded claims in the ID token.
""" """
id_token = token.get("id_token") id_token = token.get("id_token")
logger.debug("Attempting to decode JWT id_token %r", id_token)
# That has been theoritically been checked by the caller, so even though # That has been theoritically been checked by the caller, so even though
# assertion are not enabled in production, it is mainly here to appease mypy # assertion are not enabled in production, it is mainly here to appease mypy
assert id_token is not None assert id_token is not None
metadata = await self.load_metadata() metadata = await self.load_metadata()
claims_params = { claims_params = {
"nonce": nonce, "nonce": nonce,
"client_id": self._client_auth.client_id, "client_id": self._client_auth.client_id,
@ -691,38 +883,17 @@ class OidcProvider:
# in the `id_token` that we can check against. # in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"] claims_params["access_token"] = token["access_token"]
claims_options = {"iss": {"values": [metadata["issuer"]]}}
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values)
claim_options = {"iss": {"values": [metadata["issuer"]]}} claims = await self._verify_jwt(
alg_values=alg_values,
# Try to decode the keys in cache first, then retry by forcing the keys token=id_token,
# to be reloaded
jwk_set = await self.load_jwks()
try:
claims = jwt.decode(
id_token,
key=jwk_set,
claims_cls=CodeIDToken, claims_cls=CodeIDToken,
claims_options=claim_options, claims_options=claims_options,
claims_params=claims_params, claims_params=claims_params,
) )
except ValueError:
logger.info("Reloading JWKS after decode error")
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
claims = jwt.decode(
id_token,
key=jwk_set,
claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
logger.debug("Decoded id_token JWT %r; validating", claims)
claims.validate(
now=self._clock.time(), leeway=120
) # allows 2 min of clock skew
return claims return claims
@ -1043,6 +1214,146 @@ class OidcProvider:
# to be strings. # to be strings.
return str(remote_user_id) return str(remote_user_id)
async def handle_backchannel_logout(
self, request: SynapseRequest, logout_token: str
) -> None:
"""Handle an incoming request to /_synapse/client/oidc/backchannel_logout
The OIDC Provider posts a logout token to this endpoint when a user
session ends. That token is a JWT signed with the same keys as
ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
validate the JWT and figure out what session to end.
Args:
request: The request to respond to
logout_token: The logout token (a JWT) extracted from the request body
"""
# Back-Channel Logout can be disabled in the config, hence this check.
# This is not that important for now since Synapse is registered
# manually to the OP, so not specifying the backchannel-logout URI is
# as effective than disabling it here. It might make more sense if we
# support dynamic registration in Synapse at some point.
if not self._config.backchannel_logout_enabled:
logger.warning(
f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
)
# TODO: this responds with a 400 status code, which is what the OIDC
# Back-Channel Logout spec expects, but spec also suggests answering with
# a JSON object, with the `error` and `error_description` fields set, which
# we are not doing here.
# See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
raise SynapseError(
400, "OpenID Connect Back-Channel Logout is disabled for this provider"
)
metadata = await self.load_metadata()
# As per OIDC Back-Channel Logout 1.0 sec. 2.4:
# A Logout Token MUST be signed and MAY also be encrypted. The same
# keys are used to sign and encrypt Logout Tokens as are used for ID
# Tokens. If the Logout Token is encrypted, it SHOULD replicate the
# iss (issuer) claim in the JWT Header Parameters, as specified in
# Section 5.3 of [JWT].
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
# As per sec. 2.6:
# 3. Validate the iss, aud, and iat Claims in the same way they are
# validated in ID Tokens.
# Which means the audience should contain Synapse's client_id and the
# issuer should be the IdP issuer
claims_options = {
"iss": {"values": [metadata["issuer"]]},
"aud": {"values": [self.client_id]},
}
try:
claims = await self._verify_jwt(
alg_values=alg_values,
token=logout_token,
claims_cls=LogoutToken,
claims_options=claims_options,
)
except JoseError:
logger.exception("Invalid logout_token")
raise SynapseError(400, "Invalid logout_token")
# As per sec. 2.6:
# 4. Verify that the Logout Token contains a sub Claim, a sid Claim,
# or both.
# 5. Verify that the Logout Token contains an events Claim whose
# value is JSON object containing the member name
# http://schemas.openid.net/event/backchannel-logout.
# 6. Verify that the Logout Token does not contain a nonce Claim.
# This is all verified by the LogoutToken claims class, so at this
# point the `sid` claim exists and is a string.
sid: str = claims.get("sid")
# If the `sub` claim was included in the logout token, we check that it matches
# that it matches the right user. We can have cases where the `sub` claim is not
# the ID saved in database, so we let admins disable this check in config.
sub: Optional[str] = claims.get("sub")
expected_user_id: Optional[str] = None
if sub is not None and not self._config.backchannel_logout_ignore_sub:
expected_user_id = await self._store.get_user_by_external_id(
self.idp_id, sub
)
# Invalidate any running user-mapping sessions, in-flight login tokens and
# active devices
await self._sso_handler.revoke_sessions_for_provider_session_id(
auth_provider_id=self.idp_id,
auth_provider_session_id=sid,
expected_user_id=expected_user_id,
)
request.setResponseCode(200)
request.setHeader(b"Cache-Control", b"no-cache, no-store")
request.setHeader(b"Pragma", b"no-cache")
finish_request(request)
class LogoutToken(JWTClaims):
"""
Holds and verify claims of a logout token, as per
https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
"""
REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
"""Validate everything in claims payload."""
super().validate(now, leeway)
self.validate_sid()
self.validate_events()
self.validate_nonce()
def validate_sid(self) -> None:
"""Ensure the sid claim is present"""
sid = self.get("sid")
if not sid:
raise MissingClaimError("sid")
if not isinstance(sid, str):
raise InvalidClaimError("sid")
def validate_nonce(self) -> None:
"""Ensure the nonce claim is absent"""
if "nonce" in self:
raise InvalidClaimError("nonce")
def validate_events(self) -> None:
"""Ensure the events claim is present and with the right value"""
events = self.get("events")
if not events:
raise MissingClaimError("events")
if not isinstance(events, dict):
raise InvalidClaimError("events")
if "http://schemas.openid.net/event/backchannel-logout" not in events:
raise InvalidClaimError("events")
# number of seconds a newly-generated client secret should be valid for # number of seconds a newly-generated client secret should be valid for
CLIENT_SECRET_VALIDITY_SECONDS = 3600 CLIENT_SECRET_VALIDITY_SECONDS = 3600
@ -1112,6 +1423,7 @@ class JwtClientSecret:
logger.info( logger.info(
"Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
) )
jwt = JsonWebToken(header["alg"])
self._cached_secret = jwt.encode(header, payload, self._key.key) self._cached_secret = jwt.encode(header, payload, self._key.key)
self._cached_secret_replacement_time = ( self._cached_secret_replacement_time = (
expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
@ -1126,9 +1438,6 @@ class UserAttributeDict(TypedDict):
emails: List[str] emails: List[str]
C = TypeVar("C")
class OidcMappingProvider(Generic[C]): class OidcMappingProvider(Generic[C]):
"""A mapping provider maps a UserInfo object to user attributes. """A mapping provider maps a UserInfo object to user attributes.

View File

@ -191,6 +191,7 @@ class SsoHandler:
self._server_name = hs.hostname self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._error_template = hs.config.sso.sso_error_template self._error_template = hs.config.sso.sso_error_template
self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
self._profile_handler = hs.get_profile_handler() self._profile_handler = hs.get_profile_handler()
@ -1026,6 +1027,76 @@ class SsoHandler:
return True return True
async def revoke_sessions_for_provider_session_id(
self,
auth_provider_id: str,
auth_provider_session_id: str,
expected_user_id: Optional[str] = None,
) -> None:
"""Revoke any devices and in-flight logins tied to a provider session.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
auth_provider_session_id: The session ID from the provider to logout
expected_user_id: The user we're expecting to logout. If set, it will ignore
sessions belonging to other users and log an error.
"""
# Invalidate any running user-mapping sessions
to_delete = []
for session_id, session in self._username_mapping_sessions.items():
if (
session.auth_provider_id == auth_provider_id
and session.auth_provider_session_id == auth_provider_session_id
):
to_delete.append(session_id)
for session_id in to_delete:
logger.info("Revoking mapping session %s", session_id)
del self._username_mapping_sessions[session_id]
# Invalidate any in-flight login tokens
await self._store.invalidate_login_tokens_by_session_id(
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
# Fetch any device(s) in the store associated with the session ID.
devices = await self._store.get_devices_by_auth_provider_session_id(
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
# We have no guarantee that all the devices of that session are for the same
# `user_id`. Hence, we have to iterate over the list of devices and log them out
# one by one.
for device in devices:
user_id = device["user_id"]
device_id = device["device_id"]
# If the user_id associated with that device/session is not the one we got
# out of the `sub` claim, skip that device and show log an error.
if expected_user_id is not None and user_id != expected_user_id:
logger.error(
"Received a logout notification from SSO provider "
f"{auth_provider_id!r} for the user {expected_user_id!r}, but with "
f"a session ID ({auth_provider_session_id!r}) which belongs to "
f"{user_id!r}. This may happen when the SSO provider user mapper "
"uses something else than the standard attribute as mapping ID. "
"For OIDC providers, set `backchannel_logout_ignore_sub` to `true` "
"in the provider config if that is the case."
)
continue
logger.info(
"Logging out %r (device %r) via SSO (%r) logout notification (session %r).",
user_id,
device_id,
auth_provider_id,
auth_provider_session_id,
)
await self._device_handler.delete_devices(user_id, [device_id])
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie """Extract the session ID from the cookie

View File

@ -17,6 +17,9 @@ from typing import TYPE_CHECKING
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.rest.synapse.client.oidc.backchannel_logout_resource import (
OIDCBackchannelLogoutResource,
)
from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
if TYPE_CHECKING: if TYPE_CHECKING:
@ -29,6 +32,7 @@ class OIDCResource(Resource):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
Resource.__init__(self) Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs)) self.putChild(b"callback", OIDCCallbackResource(hs))
self.putChild(b"backchannel_logout", OIDCBackchannelLogoutResource(hs))
__all__ = ["OIDCResource"] __all__ = ["OIDCResource"]

View File

@ -0,0 +1,35 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from synapse.http.server import DirectServeJsonResource
from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class OIDCBackchannelLogoutResource(DirectServeJsonResource):
isLeaf = 1
def __init__(self, hs: "HomeServer"):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
async def _async_render_POST(self, request: SynapseRequest) -> None:
await self._oidc_handler.handle_backchannel_logout(request)

View File

@ -1920,6 +1920,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
self._clock.time_msec(), self._clock.time_msec(),
) )
async def invalidate_login_tokens_by_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
) -> None:
"""Invalidate login tokens with the given IdP session ID.
Args:
auth_provider_id: The SSO Identity Provider that the user authenticated with
to get this token
auth_provider_session_id: The session ID advertised by the SSO Identity
Provider
"""
await self.db_pool.simple_update(
table="login_tokens",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
updatevalues={"used_ts": self._clock.time_msec()},
desc="invalidate_login_tokens_by_session_id",
)
@cached() @cached()
async def is_guest(self, user_id: str) -> bool: async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol( res = await self.db_pool.simple_select_one_onecol(

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
@ -21,7 +22,7 @@ from twisted.web.resource import Resource
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes from synapse.api.errors import Codes, SynapseError
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.synapse.client import build_synapse_client_resource_tree
@ -32,8 +33,8 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_oidc import HAS_OIDC
from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
from tests.server import FakeChannel from tests.server import FakeChannel, make_request
from tests.unittest import override_config, skip_unless from tests.unittest import override_config, skip_unless
@ -638,19 +639,6 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
{"refresh_token": refresh_token}, {"refresh_token": refresh_token},
) )
def is_access_token_valid(self, access_token: str) -> bool:
"""
Checks whether an access token is valid, returning whether it is or not.
"""
code = self.make_request(
"GET", "/_matrix/client/v3/account/whoami", access_token=access_token
).code
# Either 200 or 401 is what we get back; anything else is a bug.
assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED}
return code == HTTPStatus.OK
def test_login_issue_refresh_token(self) -> None: def test_login_issue_refresh_token(self) -> None:
""" """
A login response should include a refresh_token only if asked. A login response should include a refresh_token only if asked.
@ -847,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.reactor.advance(59.0) self.reactor.advance(59.0)
# Both tokens should still be valid. # Both tokens should still be valid.
self.assertTrue(self.is_access_token_valid(refreshable_access_token)) self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK)
self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 61 s (just past 1 minute, the time of expiry) # Advance to 61 s (just past 1 minute, the time of expiry)
self.reactor.advance(2.0) self.reactor.advance(2.0)
# Only the non-refreshable token is still valid. # Only the non-refreshable token is still valid.
self.assertFalse(self.is_access_token_valid(refreshable_access_token)) self.helper.whoami(
self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
)
self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 599 s (just shy of 10 minutes, the time of expiry) # Advance to 599 s (just shy of 10 minutes, the time of expiry)
self.reactor.advance(599.0 - 61.0) self.reactor.advance(599.0 - 61.0)
# It's still the case that only the non-refreshable token is still valid. # It's still the case that only the non-refreshable token is still valid.
self.assertFalse(self.is_access_token_valid(refreshable_access_token)) self.helper.whoami(
self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
)
self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 601 s (just past 10 minutes, the time of expiry) # Advance to 601 s (just past 10 minutes, the time of expiry)
self.reactor.advance(2.0) self.reactor.advance(2.0)
# Now neither token is valid. # Now neither token is valid.
self.assertFalse(self.is_access_token_valid(refreshable_access_token)) self.helper.whoami(
self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
)
self.helper.whoami(
nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
)
@override_config( @override_config(
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
@ -1165,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# and no refresh token # and no refresh token
self.assertEqual(_table_length("access_tokens"), 0) self.assertEqual(_table_length("access_tokens"), 0)
self.assertEqual(_table_length("refresh_tokens"), 0) self.assertEqual(_table_length("refresh_tokens"), 0)
def oidc_config(
id: str, with_localpart_template: bool, **kwargs: Any
) -> Dict[str, Any]:
"""Sample OIDC provider config used in backchannel logout tests.
Args:
id: IDP ID for this provider
with_localpart_template: Set to `true` to have a default localpart_template in
the `user_mapping_provider` config and skip the user mapping session
**kwargs: rest of the config
Returns:
A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of
the HS config
"""
config: Dict[str, Any] = {
"idp_id": id,
"idp_name": id,
"issuer": TEST_OIDC_ISSUER,
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"scopes": ["openid"],
}
if with_localpart_template:
config["user_mapping_provider"] = {
"config": {"localpart_template": "{{ user.sub }}"}
}
else:
config["user_mapping_provider"] = {"config": {}}
config.update(kwargs)
return config
@skip_unless(HAS_OIDC, "Requires OIDC")
class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
servlets = [
account.register_servlets,
login.register_servlets,
]
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
# False, so synapse will see the requested uri as http://..., so using http in
# the public_baseurl stops Synapse trying to redirect to https.
config["public_baseurl"] = "http://synapse.test"
return config
def create_resource_dict(self) -> Dict[str, Resource]:
resource_dict = super().create_resource_dict()
resource_dict.update(build_synapse_client_resource_tree(self.hs))
return resource_dict
def submit_logout_token(self, logout_token: str) -> FakeChannel:
return self.make_request(
"POST",
"/_synapse/client/oidc/backchannel_logout",
content=f"logout_token={logout_token}",
content_is_form=True,
)
@override_config(
{
"oidc_providers": [
oidc_config(
id="oidc",
with_localpart_template=True,
backchannel_logout_enabled=True,
)
]
}
)
def test_simple_logout(self) -> None:
"""
Receiving a logout token should logout the user
"""
fake_oidc_server = self.helper.fake_oidc_server()
user = "john"
login_resp, first_grant = self.helper.login_via_oidc(
fake_oidc_server, user, with_sid=True
)
first_access_token: str = login_resp["access_token"]
self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
login_resp, second_grant = self.helper.login_via_oidc(
fake_oidc_server, user, with_sid=True
)
second_access_token: str = login_resp["access_token"]
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
self.assertNotEqual(first_grant.sid, second_grant.sid)
self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
# Logging out of the first session
logout_token = fake_oidc_server.generate_logout_token(first_grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
# Logging out of the second session
logout_token = fake_oidc_server.generate_logout_token(second_grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
@override_config(
{
"oidc_providers": [
oidc_config(
id="oidc",
with_localpart_template=True,
backchannel_logout_enabled=True,
)
]
}
)
def test_logout_during_login(self) -> None:
"""
It should revoke login tokens when receiving a logout token
"""
fake_oidc_server = self.helper.fake_oidc_server()
user = "john"
# Get an authentication, and logout before submitting the logout token
client_redirect_url = "https://x"
userinfo = {"sub": user}
channel, grant = self.helper.auth_via_oidc(
fake_oidc_server,
userinfo,
client_redirect_url,
with_sid=True,
)
# expect a confirmation page
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# fish the matrix login token out of the body of the confirmation page
m = re.search(
'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
channel.text_body,
)
assert m, channel.text_body
login_token = m.group(1)
# Submit a logout
logout_token = fake_oidc_server.generate_logout_token(grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
# Now try to exchange the login token
channel = make_request(
self.hs.get_reactor(),
self.site,
"POST",
"/login",
content={"type": "m.login.token", "token": login_token},
)
# It should have failed
self.assertEqual(channel.code, 403)
@override_config(
{
"oidc_providers": [
oidc_config(
id="oidc",
with_localpart_template=False,
backchannel_logout_enabled=True,
)
]
}
)
def test_logout_during_mapping(self) -> None:
"""
It should stop ongoing user mapping session when receiving a logout token
"""
fake_oidc_server = self.helper.fake_oidc_server()
user = "john"
# Get an authentication, and logout before submitting the logout token
client_redirect_url = "https://x"
userinfo = {"sub": user}
channel, grant = self.helper.auth_via_oidc(
fake_oidc_server,
userinfo,
client_redirect_url,
with_sid=True,
)
# Expect a user mapping page
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
# We should have a user_mapping_session cookie
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
assert cookie_headers
cookies: Dict[str, str] = {}
for h in cookie_headers:
key, value = h.split(";")[0].split("=", maxsplit=1)
cookies[key] = value
user_mapping_session_id = cookies["username_mapping_session"]
# Getting that session should not raise
session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
self.assertIsNotNone(session)
# Submit a logout
logout_token = fake_oidc_server.generate_logout_token(grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
# Now it should raise
with self.assertRaises(SynapseError):
self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
@override_config(
{
"oidc_providers": [
oidc_config(
id="oidc",
with_localpart_template=True,
backchannel_logout_enabled=False,
)
]
}
)
def test_disabled(self) -> None:
"""
Receiving a logout token should do nothing if it is disabled in the config
"""
fake_oidc_server = self.helper.fake_oidc_server()
user = "john"
login_resp, grant = self.helper.login_via_oidc(
fake_oidc_server, user, with_sid=True
)
access_token: str = login_resp["access_token"]
self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
# Logging out shouldn't work
logout_token = fake_oidc_server.generate_logout_token(grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 400)
# And the token should still be valid
self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
@override_config(
{
"oidc_providers": [
oidc_config(
id="oidc",
with_localpart_template=True,
backchannel_logout_enabled=True,
)
]
}
)
def test_no_sid(self) -> None:
"""
Receiving a logout token without `sid` during the login should do nothing
"""
fake_oidc_server = self.helper.fake_oidc_server()
user = "john"
login_resp, grant = self.helper.login_via_oidc(
fake_oidc_server, user, with_sid=False
)
access_token: str = login_resp["access_token"]
self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
# Logging out shouldn't work
logout_token = fake_oidc_server.generate_logout_token(grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 400)
# And the token should still be valid
self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
@override_config(
{
"oidc_providers": [
oidc_config(
"first",
issuer="https://first-issuer.com/",
with_localpart_template=True,
backchannel_logout_enabled=True,
),
oidc_config(
"second",
issuer="https://second-issuer.com/",
with_localpart_template=True,
backchannel_logout_enabled=True,
),
]
}
)
def test_multiple_providers(self) -> None:
"""
It should be able to distinguish login tokens from two different IdPs
"""
first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/")
second_server = self.helper.fake_oidc_server(
issuer="https://second-issuer.com/"
)
user = "john"
login_resp, first_grant = self.helper.login_via_oidc(
first_server, user, with_sid=True, idp_id="oidc-first"
)
first_access_token: str = login_resp["access_token"]
self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
login_resp, second_grant = self.helper.login_via_oidc(
second_server, user, with_sid=True, idp_id="oidc-second"
)
second_access_token: str = login_resp["access_token"]
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
# `sid` in the fake providers are generated by a counter, so the first grant of
# each provider should give the same SID
self.assertEqual(first_grant.sid, second_grant.sid)
self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
# Logging out of the first session
logout_token = first_server.generate_logout_token(first_grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
# Logging out of the second session
logout_token = second_server.generate_logout_token(second_grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED)

View File

@ -553,6 +553,34 @@ class RestHelper:
return channel.json_body return channel.json_body
def whoami(
self,
access_token: str,
expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK,
) -> JsonDict:
"""Perform a 'whoami' request, which can be a quick way to check for access
token validity
Args:
access_token: The user token to use during the request
expect_code: The return code to expect from attempting the whoami request
"""
channel = make_request(
self.hs.get_reactor(),
self.site,
"GET",
"account/whoami",
access_token=access_token,
)
assert channel.code == expect_code, "Exepcted: %d, got %d, resp: %r" % (
expect_code,
channel.code,
channel.result["body"],
)
return channel.json_body
def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
"""Create a ``FakeOidcServer``. """Create a ``FakeOidcServer``.
@ -572,6 +600,7 @@ class RestHelper:
fake_server: FakeOidcServer, fake_server: FakeOidcServer,
remote_user_id: str, remote_user_id: str,
with_sid: bool = False, with_sid: bool = False,
idp_id: Optional[str] = None,
expected_status: int = 200, expected_status: int = 200,
) -> Tuple[JsonDict, FakeAuthorizationGrant]: ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC """Log in (as a new user) via OIDC
@ -588,7 +617,11 @@ class RestHelper:
client_redirect_url = "https://x" client_redirect_url = "https://x"
userinfo = {"sub": remote_user_id} userinfo = {"sub": remote_user_id}
channel, grant = self.auth_via_oidc( channel, grant = self.auth_via_oidc(
fake_server, userinfo, client_redirect_url, with_sid=with_sid fake_server,
userinfo,
client_redirect_url,
with_sid=with_sid,
idp_id=idp_id,
) )
# expect a confirmation page # expect a confirmation page
@ -623,6 +656,7 @@ class RestHelper:
client_redirect_url: Optional[str] = None, client_redirect_url: Optional[str] = None,
ui_auth_session_id: Optional[str] = None, ui_auth_session_id: Optional[str] = None,
with_sid: bool = False, with_sid: bool = False,
idp_id: Optional[str] = None,
) -> Tuple[FakeChannel, FakeAuthorizationGrant]: ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
"""Perform an OIDC authentication flow via a mock OIDC provider. """Perform an OIDC authentication flow via a mock OIDC provider.
@ -648,6 +682,7 @@ class RestHelper:
ui_auth_session_id: if set, we will perform a UI Auth flow. The session id ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
of the UI auth. of the UI auth.
with_sid: if True, generates a random `sid` (OIDC session ID) with_sid: if True, generates a random `sid` (OIDC session ID)
idp_id: if set, explicitely chooses one specific IDP
Returns: Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint. A FakeChannel containing the result of calling the OIDC callback endpoint.
@ -665,7 +700,9 @@ class RestHelper:
oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
else: else:
# otherwise, hit the login redirect endpoint # otherwise, hit the login redirect endpoint
oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) oauth_uri = self.initiate_sso_login(
client_redirect_url, cookies, idp_id=idp_id
)
# we now have a URI for the OIDC IdP, but we skip that and go straight # we now have a URI for the OIDC IdP, but we skip that and go straight
# back to synapse's OIDC callback resource. However, we do need the "state" # back to synapse's OIDC callback resource. However, we do need the "state"
@ -742,7 +779,10 @@ class RestHelper:
return channel, grant return channel, grant
def initiate_sso_login( def initiate_sso_login(
self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] self,
client_redirect_url: Optional[str],
cookies: MutableMapping[str, str],
idp_id: Optional[str] = None,
) -> str: ) -> str:
"""Make a request to the login-via-sso redirect endpoint, and return the target """Make a request to the login-via-sso redirect endpoint, and return the target
@ -753,6 +793,7 @@ class RestHelper:
client_redirect_url: the client redirect URL to pass to the login redirect client_redirect_url: the client redirect URL to pass to the login redirect
endpoint endpoint
cookies: any cookies returned will be added to this dict cookies: any cookies returned will be added to this dict
idp_id: if set, explicitely chooses one specific IDP
Returns: Returns:
the URI that the client gets redirected to (ie, the SSO server) the URI that the client gets redirected to (ie, the SSO server)
@ -761,6 +802,12 @@ class RestHelper:
if client_redirect_url: if client_redirect_url:
params["redirectUrl"] = client_redirect_url params["redirectUrl"] = client_redirect_url
uri = "/_matrix/client/r0/login/sso/redirect"
if idp_id is not None:
uri = f"{uri}/{idp_id}"
uri = f"{uri}?{urllib.parse.urlencode(params)}"
# hit the redirect url (which should redirect back to the redirect url. This # hit the redirect url (which should redirect back to the redirect url. This
# is the easiest way of figuring out what the Host header ought to be set to # is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy. # to keep Synapse happy.
@ -768,7 +815,7 @@ class RestHelper:
self.hs.get_reactor(), self.hs.get_reactor(),
self.site, self.site,
"GET", "GET",
"/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), uri,
) )
assert channel.code == 302 assert channel.code == 302

View File

@ -362,6 +362,12 @@ def make_request(
# Twisted expects to be at the end of the content when parsing the request. # Twisted expects to be at the end of the content when parsing the request.
req.content.seek(0, SEEK_END) req.content.seek(0, SEEK_END)
# Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded
# bodies if the Content-Length header is missing
req.requestHeaders.addRawHeader(
b"Content-Length", str(len(content)).encode("ascii")
)
if access_token: if access_token:
req.requestHeaders.addRawHeader( req.requestHeaders.addRawHeader(
b"Authorization", b"Bearer " + access_token.encode("ascii") b"Authorization", b"Bearer " + access_token.encode("ascii")

View File

@ -51,6 +51,8 @@ class FakeOidcServer:
get_userinfo_handler: Mock get_userinfo_handler: Mock
post_token_handler: Mock post_token_handler: Mock
sid_counter: int = 0
def __init__(self, clock: Clock, issuer: str): def __init__(self, clock: Clock, issuer: str):
from authlib.jose import ECKey, KeySet from authlib.jose import ECKey, KeySet
@ -146,7 +148,7 @@ class FakeOidcServer:
return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
now = self._clock.time() now = int(self._clock.time())
id_token = { id_token = {
**grant.userinfo, **grant.userinfo,
"iss": self.issuer, "iss": self.issuer,
@ -166,6 +168,26 @@ class FakeOidcServer:
return self._sign(id_token) return self._sign(id_token)
def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str:
now = int(self._clock.time())
logout_token = {
"iss": self.issuer,
"aud": grant.client_id,
"iat": now,
"jti": random_string(10),
"events": {
"http://schemas.openid.net/event/backchannel-logout": {},
},
}
if grant.sid is not None:
logout_token["sid"] = grant.sid
if "sub" in grant.userinfo:
logout_token["sub"] = grant.userinfo["sub"]
return self._sign(logout_token)
def id_token_override(self, overrides: dict): def id_token_override(self, overrides: dict):
"""Temporarily patch the ID token generated by the token endpoint.""" """Temporarily patch the ID token generated by the token endpoint."""
return patch.object(self, "_id_token_overrides", overrides) return patch.object(self, "_id_token_overrides", overrides)
@ -183,7 +205,8 @@ class FakeOidcServer:
code = random_string(10) code = random_string(10)
sid = None sid = None
if with_sid: if with_sid:
sid = random_string(10) sid = str(self.sid_counter)
self.sid_counter += 1
grant = FakeAuthorizationGrant( grant = FakeAuthorizationGrant(
userinfo=userinfo, userinfo=userinfo,