1045 lines
38 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
import attr
import pymacaroons
from authlib.common.security import generate_token
from authlib.jose import JsonWebToken
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
MacaroonDeserializationException,
MacaroonInvalidSignatureException,
)
from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.push.mailer import load_jinja2_templates
from synapse.server import HomeServer
from synapse.types import UserID, map_username_to_mxid_localpart
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Name of the cookie that carries the signed OIDC session token. It is set
# when redirecting to the provider and read back (then cleared) in the
# callback.
SESSION_COOKIE_NAME = b"oidc_session"

#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
#: OpenID.Core sec 3.1.3.3.
Token = TypedDict(
    "Token",
    {
        "access_token": str,
        "token_type": str,
        "id_token": Optional[str],
        "refresh_token": Optional[str],
        "expires_in": int,
        "scope": Optional[str],
    },
)

#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
#: there is no real point in doing this in our case.
JWK = Dict[str, str]

#: A JWK Set, as per RFC7517 sec 5.
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
class OidcError(Exception):
    """Used to catch errors when calling the token_endpoint.

    Attributes:
        error: A technical identifier for the error, e.g. one of the OAuth2
            error codes such as ``invalid_request`` or ``server_error``.
        error_description: An optional human-readable description.
    """

    def __init__(self, error, error_description=None):
        """
        Args:
            error: The technical error identifier.
            error_description: An optional human-readable description.
        """
        # Forward the arguments to ``Exception`` so that ``args`` and
        # ``repr()`` carry the error, and so that the exception can be rebuilt
        # by ``copy``/``pickle`` (which call ``cls(*self.args)``).
        super().__init__(error, error_description)
        self.error = error
        self.error_description = error_description

    def __str__(self):
        """Return ``"<error>: <description>"``, or just the error if there is
        no description."""
        if self.error_description:
            return "{}: {}".format(self.error, self.error_description)
        return self.error
class MappingException(Exception):
    """Used to catch errors when mapping the UserInfo object.

    Raised by ``OidcHandler._map_userinfo_to_user``; the OIDC callback
    catches it and shows a ``mapping_error`` page to the user.
    """
class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: HomeServer):
    """Set up the handler from the homeserver and its OIDC configuration.

    Args:
        hs: The homeserver. Its ``config`` supplies the ``oidc_*`` settings,
            and it provides the HTTP client, handlers, datastore and clock
            stored on this handler.
    """
    self._callback_url = hs.config.oidc_callback_url  # type: str
    self._scopes = hs.config.oidc_scopes  # type: List[str]
    self._client_auth = ClientAuth(
        hs.config.oidc_client_id,
        hs.config.oidc_client_secret,
        hs.config.oidc_client_auth_method,
    )  # type: ClientAuth
    self._client_auth_method = hs.config.oidc_client_auth_method  # type: str
    # Statically configured metadata; it is completed with the discovery
    # document in `load_metadata` when discovery is enabled.
    self._provider_metadata = OpenIDProviderMetadata(
        issuer=hs.config.oidc_issuer,
        authorization_endpoint=hs.config.oidc_authorization_endpoint,
        token_endpoint=hs.config.oidc_token_endpoint,
        userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
        jwks_uri=hs.config.oidc_jwks_uri,
    )  # type: OpenIDProviderMetadata
    # Reset to False by `load_metadata` once the discovery document was loaded.
    self._provider_needs_discovery = hs.config.oidc_discover  # type: bool
    self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
        hs.config.oidc_user_mapping_provider_config
    )  # type: OidcMappingProvider
    self._skip_verification = hs.config.oidc_skip_verification  # type: bool

    self._http_client = hs.get_proxied_http_client()
    self._auth_handler = hs.get_auth_handler()
    self._registration_handler = hs.get_registration_handler()
    self._datastore = hs.get_datastore()
    self._clock = hs.get_clock()
    self._hostname = hs.hostname  # type: str
    self._server_name = hs.config.server_name  # type: str
    # Key used to sign and verify the session macaroons.
    self._macaroon_secret_key = hs.config.macaroon_secret_key
    # Template rendered by `_render_error`.
    self._error_template = load_jinja2_templates(
        hs.config.sso_template_dir, ["sso_error.html"]
    )[0]

    # identifier for the external_ids table
    self._auth_provider_id = "oidc"
def _render_error(
    self, request, error: str, error_description: Optional[str] = None
) -> None:
    """Show an error page to the user.

    Renders the ``sso_error.html`` template loaded at construction time and
    sends it as the response to ``request`` with a 400 status code.

    Args:
        request: The incoming request from the browser, which gets answered
            with the HTML error page.
        error: A technical identifier for this error. Those include
            well-known OAuth2/OIDC error types like invalid_request or
            access_denied.
        error_description: A human-readable description of the error.
    """
    context = {"error": error, "error_description": error_description}
    page = self._error_template.render(**context)
    respond_with_html(request, 400, page)
def _validate_metadata(self):
"""Verifies the provider metadata.
This checks the validity of the currently loaded provider. Not
everything is checked, only:
- ``issuer``
- ``authorization_endpoint``
- ``token_endpoint``
- ``response_types_supported`` (checks if "code" is in it)
- ``jwks_uri``
Raises:
ValueError: if something in the provider is not valid
"""
# Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin)
if self._skip_verification is True:
return
m = self._provider_metadata
m.validate_issuer()
m.validate_authorization_endpoint()
m.validate_token_endpoint()
if m.get("token_endpoint_auth_methods_supported") is not None:
m.validate_token_endpoint_auth_methods_supported()
if (
self._client_auth_method
not in m["token_endpoint_auth_methods_supported"]
):
raise ValueError(
'"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format(
auth_method=self._client_auth_method,
supported=m["token_endpoint_auth_methods_supported"],
)
)
if m.get("response_types_supported") is not None:
m.validate_response_types_supported()
if "code" not in m["response_types_supported"]:
raise ValueError(
'"code" not in "response_types_supported" (%r)'
% (m["response_types_supported"],)
)
# If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
if self._uses_userinfo:
if m.get("userinfo_endpoint") is None:
raise ValueError(
'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
if m.get("jwks") is None:
if m.get("jwks_uri") is not None:
m.validate_jwks_uri()
else:
raise ValueError('"jwks_uri" must be set')
@property
def _uses_userinfo(self) -> bool:
    """Returns True if the ``userinfo_endpoint`` should be used.

    This is based on the requested scopes: if the scopes include
    ``openid``, the provider should give us an ID token containing the
    user information. If not, we should fetch it using the
    ``access_token`` with the ``userinfo_endpoint``.
    """
    # Maybe that should be user-configurable and not inferred?
    return "openid" not in self._scopes
async def load_metadata(self) -> OpenIDProviderMetadata:
    """Return the provider metadata, discovering it first if needed.

    When discovery is enabled, the OpenID Discovery document is fetched on
    the first call and merged into the configured metadata; later calls
    reuse the merged values. The metadata is validated on every call.

    Raises:
        ValueError: if something in the provider is not valid

    Returns:
        The provider's metadata.
    """
    # The discovery document only needs to be loaded once.
    # FIXME: should there be a lock here?
    if self._provider_needs_discovery:
        issuer = self._provider_metadata["issuer"]
        discovery_url = get_well_known_url(issuer, external=True)
        discovered = await self._http_client.get_json(discovery_url)
        # TODO: maybe update the other way around to let user override some values?
        self._provider_metadata.update(discovered)
        self._provider_needs_discovery = False

    self._validate_metadata()
    return self._provider_metadata
async def load_jwks(self, force: bool = False) -> JWKS:
    """Return the JSON Web Key Set used to sign ID tokens.

    When the ``userinfo_endpoint`` is not used, user infos come from the ID
    token, a JWT signed with keys published by the provider. Those keys are
    fetched from ``jwks_uri`` and cached in the provider metadata.

    Args:
        force: Fetch the keys again even if a key set is already cached.

    Returns:
        The key set

        Looks like this::

            {
                'keys': [
                    {
                        'kid': 'abcdef',
                        'kty': 'RSA',
                        'alg': 'RS256',
                        'use': 'sig',
                        'e': 'XXXX',
                        'n': 'XXXX',
                    }
                ]
            }

    Raises:
        RuntimeError: if the keys must be fetched but the metadata has no
            ``jwks_uri``.
    """
    if self._uses_userinfo:
        # No JWT is verified in this mode, so an empty key set is enough.
        return {"keys": []}

    # The key set may already sit in the metadata: either the provider put
    # it directly in its discovery document, or an earlier call cached it.
    metadata = await self.load_metadata()
    cached = metadata.get("jwks")
    if not force and cached is not None:
        return cached

    jwks_uri = metadata.get("jwks_uri")
    if not jwks_uri:
        raise RuntimeError('Missing "jwks_uri" in metadata')

    fetched = await self._http_client.get_json(jwks_uri)
    # Keep the keys around for the next call.
    self._provider_metadata["jwks"] = fetched
    return fetched
async def _exchange_code(self, code: str) -> Token:
    """Exchange an authorization code for a token.

    This calls the ``token_endpoint`` with the authorization code we
    received in the callback to exchange it for a token. The call uses the
    ``ClientAuth`` to authenticate the client with its ID and secret.

    See:
       https://tools.ietf.org/html/rfc6749#section-3.2
       https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint

    Args:
        code: The authorization code we got from the callback.

    Returns:
        A dict containing various tokens.

        May look like this::

            {
                'token_type': 'bearer',
                'access_token': 'abcdef',
                'expires_in': 3599,
                'id_token': 'ghijkl',
                'refresh_token': 'mnopqr',
            }

    Raises:
        OidcError: when the ``token_endpoint`` returned an error.
        ValueError: when a non-5xx response does not have a JSON body.
    """
    metadata = await self.load_metadata()
    token_endpoint = metadata.get("token_endpoint")
    headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "User-Agent": self._http_client.user_agent,
        "Accept": "application/json",
    }

    args = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": self._callback_url,
    }
    body = urlencode(args, True)

    # Fill the body/headers with credentials
    uri, headers, body = self._client_auth.prepare(
        method="POST", uri=token_endpoint, headers=headers, body=body
    )
    # The HTTP client expects a list of values for each header.
    headers = {k: [v] for (k, v) in headers.items()}

    # Do the actual request
    # We're not using the SimpleHttpClient util methods as we don't want to
    # check the HTTP status code and we do the body encoding ourself.
    response = await self._http_client.request(
        method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
    )

    # This is used in multiple error messages below
    status = "{code} {phrase}".format(
        code=response.code, phrase=response.phrase.decode("utf-8")
    )

    resp_body = await make_deferred_yieldable(readBody(response))

    if response.code >= 500:
        # In case of a server error, we should first try to decode the body
        # and check for an error field. If not, we respond with a generic
        # error message.
        try:
            resp = json.loads(resp_body.decode("utf-8"))
            error = resp["error"]
            description = resp.get("error_description", error)
        except (ValueError, KeyError, TypeError):
            # ValueError covers the body/JSON decoding, KeyError a missing
            # "error" field, and TypeError a JSON body which is not an
            # object (e.g. a list or a bare string).
            error = "server_error"
            # This must be a plain string: it is shown to the user and
            # formatted by `OidcError.__str__`.
            description = (
                'Authorization server responded with a "{status}" error '
                "while exchanging the authorization code."
            ).format(status=status)

        raise OidcError(error, description)

    # Since it is a not a 5xx code, body should be a valid JSON. It will
    # raise if not.
    resp = json.loads(resp_body.decode("utf-8"))

    if "error" in resp:
        error = resp["error"]
        # In case the authorization server responded with an error field,
        # it should be a 4xx code. If not, log it but don't do
        # anything special and report the original error message.
        if response.code < 400:
            logger.debug(
                "Invalid response from the authorization server: "
                'responded with a "{status}" '
                "but body has an error field: {error!r}".format(
                    status=status, error=resp["error"]
                )
            )

        description = resp.get("error_description", error)
        raise OidcError(error, description)

    # Now, this should not be an error. According to RFC6749 sec 5.1, it
    # should be a 200 code. We're a bit more flexible than that, and will
    # only throw on a 4xx code.
    if response.code >= 400:
        description = (
            'Authorization server responded with a "{status}" error '
            'but did not include an "error" field in its response.'.format(
                status=status
            )
        )
        logger.warning(description)
        # Body was still valid JSON. Might be useful to log it for debugging.
        logger.warning("Code exchange response: {resp!r}".format(resp=resp))
        raise OidcError("server_error", description)

    return resp
async def _fetch_userinfo(self, token: Token) -> UserInfo:
    """Query the ``userinfo_endpoint`` for information about the user.

    Args:
        token: the token given by the ``token_endpoint``.
            Must include an ``access_token`` field, which is sent as a
            bearer token.

    Returns:
        UserInfo: an object representing the user.
    """
    metadata = await self.load_metadata()

    endpoint = metadata["userinfo_endpoint"]
    bearer = "Bearer {}".format(token["access_token"])
    payload = await self._http_client.get_json(
        endpoint, headers={"Authorization": [bearer]},
    )

    return UserInfo(payload)
async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
    """Decode and validate the ``id_token`` and return it as a UserInfo.

    The token is first decoded with the cached key set. If that fails with a
    ``ValueError``, the key set is reloaded from the provider and decoding is
    attempted once more.

    Args:
        token: the token given by the ``token_endpoint``.
            Must include an ``id_token`` field.
        nonce: the nonce value originally sent in the initial authorization
            request. This value should match the one inside the token.

    Returns:
        An object representing the user.
    """
    metadata = await self.load_metadata()

    claims_params = {
        "nonce": nonce,
        "client_id": self._client_auth.client_id,
    }
    has_access_token = "access_token" in token
    if has_access_token:
        # If we got an `access_token`, there should be an `at_hash` claim
        # in the `id_token` that we can check against.
        claims_params["access_token"] = token["access_token"]
    claims_cls = CodeIDToken if has_access_token else ImplicitIDToken

    jwt = JsonWebToken(
        metadata.get("id_token_signing_alg_values_supported", ["RS256"])
    )
    claim_options = {"iss": {"values": [metadata["issuer"]]}}

    def decode_with(keys):
        # Decode the ID token against the given key set.
        return jwt.decode(
            token["id_token"],
            key=keys,
            claims_cls=claims_cls,
            claims_options=claim_options,
            claims_params=claims_params,
        )

    # Loading the cached keys happens outside of the `try`: only a decode
    # failure should trigger a reload of the key set.
    cached_keys = await self.load_jwks()
    try:
        claims = decode_with(cached_keys)
    except ValueError:
        logger.info("Reloading JWKS after decode error")
        fresh_keys = await self.load_jwks(force=True)  # try reloading the jwks
        claims = decode_with(fresh_keys)

    claims.validate(leeway=120)  # allows 2 min of clock skew
    return UserInfo(claims)
async def handle_redirect_request(
    self,
    request: SynapseRequest,
    client_redirect_url: bytes,
    ui_auth_session_id: Optional[str] = None,
) -> str:
    """Handle an incoming request to /login/sso/redirect

    Builds the URL of the provider's authorization endpoint, carrying these
    parameters:

      - ``client_id``: the client ID set in ``oidc_config.client_id``
      - ``response_type``: ``code``
      - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
      - ``scope``: the list of scopes set in ``oidc_config.scopes``
      - ``state``: a random string
      - ``nonce``: a random string

    In addition to generating the redirect URL, a cookie is set on the
    request. It holds a signed macaroon token containing the state, the
    nonce and the client_redirect_url, which are checked when the client
    comes back from the provider.

    Args:
        request: the incoming request from the browser.
            We'll respond to it with a redirect and a cookie.
        client_redirect_url: the URL that we should redirect the client to
            when everything is done
        ui_auth_session_id: The session ID of the ongoing UI Auth (or
            None if this is a login).

    Returns:
        The redirect URL to the authorization endpoint.
    """
    state = generate_token()
    nonce = generate_token()

    session_token = self._generate_oidc_session_token(
        state=state,
        nonce=nonce,
        client_redirect_url=client_redirect_url.decode(),
        ui_auth_session_id=ui_auth_session_id,
    )
    # The cookie is only sent back to the OIDC endpoints and lives for an
    # hour, like the token it carries by default.
    cookie_options = {
        "path": "/_synapse/oidc",
        "max_age": "3600",
        "httpOnly": True,
        "sameSite": "lax",
    }
    request.addCookie(SESSION_COOKIE_NAME, session_token, **cookie_options)

    metadata = await self.load_metadata()
    return prepare_grant_uri(
        metadata.get("authorization_endpoint"),
        client_id=self._client_auth.client_id,
        response_type="code",
        redirect_uri=self._callback_url,
        scope=self._scopes,
        state=state,
        nonce=nonce,
    )
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
    """Handle an incoming request to /_synapse/oidc/callback

    Since we might want to display OIDC-related errors in a user-friendly
    way, we don't raise SynapseError from here. Instead, we call
    ``self._render_error`` which displays an HTML page for the error.

    Most of the OpenID Connect logic happens here:

      - first, we check if there was any error returned by the provider and
        display it
      - then we fetch the session cookie, decode and verify it
      - the ``state`` query parameter should match with the one stored in the
        session cookie
      - once we know this session is legit, exchange the code with the
        provider using the ``token_endpoint`` (see ``_exchange_code``)
      - once we have the token, use it to either extract the UserInfo from
        the ``id_token`` (``_parse_id_token``), or use the ``access_token``
        to fetch UserInfo from the ``userinfo_endpoint``
        (``_fetch_userinfo``)
      - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and
        finish the login

    Args:
        request: the incoming request from the browser.
    """
    # The provider might redirect with an error.
    # In that case, just display it as-is.
    if b"error" in request.args:
        # error response from the auth server. see:
        #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
        #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
        error = request.args[b"error"][0].decode()
        description = request.args.get(b"error_description", [b""])[0].decode()

        # Most of the errors returned by the provider could be due to
        # either the provider misbehaving or Synapse being misconfigured.
        # The only exception to that is "access_denied", where the user
        # probably cancelled the login flow. In other cases, log those errors.
        if error != "access_denied":
            logger.error("Error from the OIDC provider: %s %s", error, description)

        self._render_error(request, error, description)
        return

    # otherwise, it is presumably a successful response. see:
    #   https://tools.ietf.org/html/rfc6749#section-4.1.2

    # Fetch the session cookie
    session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
    if session is None:
        logger.info("No session cookie found")
        self._render_error(request, "missing_session", "No session cookie found")
        return

    # Remove the cookie. There is a good chance that if the callback failed
    # once, it will fail next time and the code will already be exchanged.
    # Removing it early avoids spamming the provider with token requests.
    request.addCookie(
        SESSION_COOKIE_NAME,
        b"",
        path="/_synapse/oidc",
        expires="Thu, Jan 01 1970 00:00:00 UTC",
        httpOnly=True,
        sameSite="lax",
    )

    # Check for the state query parameter
    if b"state" not in request.args:
        logger.info("State parameter is missing")
        self._render_error(request, "invalid_request", "State parameter is missing")
        return

    state = request.args[b"state"][0].decode()

    # Deserialize the session token and verify it.
    try:
        (
            nonce,
            client_redirect_url,
            ui_auth_session_id,
        ) = self._verify_oidc_session_token(session, state)
    except MacaroonDeserializationException as e:
        logger.exception("Invalid session")
        self._render_error(request, "invalid_session", str(e))
        return
    except MacaroonInvalidSignatureException as e:
        logger.exception("Could not verify session")
        self._render_error(request, "mismatching_session", str(e))
        return

    # Exchange the code with the provider
    if b"code" not in request.args:
        logger.info("Code parameter is missing")
        self._render_error(request, "invalid_request", "Code parameter is missing")
        return

    logger.debug("Exchanging code")
    code = request.args[b"code"][0].decode()
    try:
        token = await self._exchange_code(code)
    except OidcError as e:
        logger.exception("Could not exchange code")
        self._render_error(request, e.error, e.error_description)
        return

    logger.debug("Successfully obtained OAuth2 access token")

    # Now that we have a token, get the userinfo, either by decoding the
    # `id_token` or by fetching the `userinfo_endpoint`.
    if self._uses_userinfo:
        logger.debug("Fetching userinfo")
        try:
            userinfo = await self._fetch_userinfo(token)
        except Exception as e:
            logger.exception("Could not fetch userinfo")
            self._render_error(request, "fetch_error", str(e))
            return
    else:
        logger.debug("Extracting userinfo from id_token")
        try:
            userinfo = await self._parse_id_token(token, nonce=nonce)
        except Exception as e:
            logger.exception("Invalid id_token")
            self._render_error(request, "invalid_token", str(e))
            return

    # Call the mapper to register/login the user
    try:
        user_id = await self._map_userinfo_to_user(userinfo, token)
    except MappingException as e:
        logger.exception("Could not map user")
        self._render_error(request, "mapping_error", str(e))
        return

    # and finally complete the login: a UI auth session ID in the session
    # token means this was a UI auth flow rather than a login.
    if ui_auth_session_id:
        await self._auth_handler.complete_sso_ui_auth(
            user_id, ui_auth_session_id, request
        )
    else:
        await self._auth_handler.complete_sso_login(
            user_id, request, client_redirect_url
        )
def _generate_oidc_session_token(
    self,
    state: str,
    nonce: str,
    client_redirect_url: str,
    ui_auth_session_id: Optional[str],
    duration_in_ms: int = (60 * 60 * 1000),
) -> str:
    """Build the signed macaroon describing an OIDC session.

    Each authorization flow started by Synapse gets a random state and a
    random nonce. Both are sent to the provider and must be checked when the
    client comes back, so they are stored in this token, together with the
    client_redirect_url needed to complete the SSO login flow.

    Args:
        state: The ``state`` parameter passed to the OIDC provider.
        nonce: The ``nonce`` parameter passed to the OIDC provider.
        client_redirect_url: The URL the client gave when it initiated the
            flow.
        ui_auth_session_id: The session ID of the ongoing UI Auth (or
            None if this is a login).
        duration_in_ms: An optional duration for the token in milliseconds.
            Defaults to an hour.

    Returns:
        A signed macaroon token with the session information.
    """
    # Caveats are collected in the order in which they are added.
    caveats = [
        "gen = 1",
        "type = session",
        "state = %s" % (state,),
        "nonce = %s" % (nonce,),
        "client_redirect_url = %s" % (client_redirect_url,),
    ]
    if ui_auth_session_id:
        caveats.append("ui_auth_session_id = %s" % (ui_auth_session_id,))

    expiry = self._clock.time_msec() + duration_in_ms
    caveats.append("time < %d" % (expiry,))

    macaroon = pymacaroons.Macaroon(
        location=self._server_name, identifier="key", key=self._macaroon_secret_key,
    )
    for caveat in caveats:
        macaroon.add_first_party_caveat(caveat)

    return macaroon.serialize()
def _verify_oidc_session_token(
    self, session: bytes, state: str
) -> Tuple[str, str, Optional[str]]:
    """Check a session token and read the session data out of it.

    The token is verified against this homeserver's macaroon secret key, the
    given ``state`` and the expiry stored in the token.

    Args:
        session: The session token to verify
        state: The state the OIDC provider gave back

    Returns:
        The nonce, client_redirect_url, and ui_auth_session_id for this session
    """
    macaroon = pymacaroons.Macaroon.deserialize(session)

    verifier = pymacaroons.Verifier()
    for exact_caveat in ("gen = 1", "type = session", "state = %s" % (state,)):
        verifier.satisfy_exact(exact_caveat)
    # Sometimes there's a UI auth session ID, it seems to be OK to attempt
    # to always satisfy this.
    for caveat_prefix in ("nonce = ", "client_redirect_url = ", "ui_auth_session_id = "):
        # Bind the prefix as a default so each predicate keeps its own value.
        verifier.satisfy_general(
            lambda c, caveat_prefix=caveat_prefix: c.startswith(caveat_prefix)
        )
    verifier.satisfy_general(self._verify_expiry)

    verifier.verify(macaroon, self._macaroon_secret_key)

    # The token is trusted from here on: read the values back from it.
    nonce = self._get_value_from_macaroon(macaroon, "nonce")
    client_redirect_url = self._get_value_from_macaroon(
        macaroon, "client_redirect_url"
    )
    try:
        ui_auth_session_id = self._get_value_from_macaroon(
            macaroon, "ui_auth_session_id"
        )  # type: Optional[str]
    except ValueError:
        # No such caveat: the token was generated for a plain login.
        ui_auth_session_id = None

    return nonce, client_redirect_url, ui_auth_session_id
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
    """Read the value of the first ``<key> = <value>`` caveat of a macaroon.

    Args:
        macaroon: the token
        key: the key of the caveat to extract

    Returns:
        The extracted value

    Raises:
        ValueError: if the caveat was not in the macaroon
    """
    prefix = key + " = "
    values = (
        caveat.caveat_id[len(prefix) :]
        for caveat in macaroon.caveats
        if caveat.caveat_id.startswith(prefix)
    )
    # Slicing a string never yields None, so None can only mean "no match".
    value = next(values, None)
    if value is None:
        raise ValueError("No %s caveat in macaroon" % (key,))
    return value
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
return now < expiry
async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
    """Find or create the Matrix user matching a UserInfo object.

    The mapping provider extracts from the UserInfo an identifier which is
    unique to the user (usually the `sub` claim, configurable with
    `oidc_config.subject_claim`). It is looked up as an `external_id`.

    When no user is mapped to it yet, a new user is registered, with a
    localpart and a display name derived from the UserInfo, and the mapping
    is recorded.

    If a user already exists with the mxid we've mapped, raise an exception.

    Args:
        userinfo: an object representing the user
        token: a dict with the tokens obtained from the provider

    Raises:
        MappingException: if there was an error while mapping some properties

    Returns:
        The mxid of the user
    """
    mapper = self._user_mapping_provider
    provider_id = self._auth_provider_id

    try:
        remote_user_id = mapper.get_remote_user_id(userinfo)
    except Exception as e:
        raise MappingException(
            "Failed to extract subject from OIDC response: %s" % (e,)
        )

    logger.info(
        "Looking for existing mapping for user %s:%s", provider_id, remote_user_id,
    )

    existing_user_id = await self._datastore.get_user_by_external_id(
        provider_id, remote_user_id,
    )
    if existing_user_id is not None:
        logger.info("Found existing mapping %s", existing_user_id)
        return existing_user_id

    # Unknown remote user: work out which Matrix user to create.
    try:
        attributes = await mapper.map_user_attributes(userinfo, token)
    except Exception as e:
        raise MappingException(
            "Could not extract user attributes from OIDC response: " + str(e)
        )

    logger.debug(
        "Retrieved user attributes from user mapping provider: %r", attributes
    )

    if not attributes["localpart"]:
        raise MappingException("localpart is empty")

    localpart = map_username_to_mxid_localpart(attributes["localpart"])
    mxid = UserID(localpart, self._hostname).to_string()

    if await self._datastore.get_users_by_id_case_insensitive(mxid):
        # This mxid is taken
        raise MappingException("mxid '{}' is already taken".format(mxid))

    # It's the first time this user is logging in and the mapped mxid was
    # not taken, register the user
    new_user_id = await self._registration_handler.register_user(
        localpart=localpart, default_display_name=attributes["display_name"],
    )
    await self._datastore.record_user_external_id(
        provider_id, remote_user_id, new_user_id,
    )
    return new_user_id
#: The attributes a mapping provider returns for a user: the localpart to
#: register and an optional display name.
UserAttribute = TypedDict(
    "UserAttribute", {"localpart": str, "display_name": Optional[str]}
)

#: Type of the config object of a mapping provider.
C = TypeVar("C")
class OidcMappingProvider(Generic[C]):
    """A mapping provider maps a UserInfo object to user attributes.

    It should provide the API described by this class. Apart from
    ``__init__``, every method here raises ``NotImplementedError`` and is
    meant to be overridden.
    """

    def __init__(self, config: C):
        """
        Args:
            config: A custom config object from this module, parsed by ``parse_config()``
        """

    @staticmethod
    def parse_config(config: dict) -> C:
        """Parse the dict provided by the homeserver's config

        Args:
            config: A dictionary containing configuration options for this provider

        Returns:
            A custom config object for this module
        """
        raise NotImplementedError()

    def get_remote_user_id(self, userinfo: UserInfo) -> str:
        """Get a unique user ID for this user.

        Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object.

        Args:
            userinfo: An object representing the user given by the OIDC provider

        Returns:
            A unique user ID
        """
        raise NotImplementedError()

    async def map_user_attributes(
        self, userinfo: UserInfo, token: Token
    ) -> UserAttribute:
        """Map a ``UserInfo`` object into user attributes.

        Args:
            userinfo: An object representing the user given by the OIDC provider
            token: A dict with the tokens returned by the provider

        Returns:
            A dict containing the ``localpart`` and (optionally) the ``display_name``
        """
        raise NotImplementedError()
# Used to clear out "None" values in templates
def jinja_finalize(thing):
    """Return ``thing`` unchanged, except for ``None`` which becomes ``""``."""
    if thing is None:
        return ""
    return thing
# Jinja environment used to compile the mapping templates; `jinja_finalize`
# is installed as its ``finalize`` hook.
env = Environment(finalize=jinja_finalize)
@attr.s
class JinjaOidcMappingConfig:
    """Parsed configuration of the ``JinjaOidcMappingProvider``."""

    # Name of the UserInfo claim used as the remote user ID.
    subject_claim = attr.ib()  # type: str
    # Template rendering the localpart of the user.
    localpart_template = attr.ib()  # type: Template
    # Template rendering the display name of the user, if one is configured.
    display_name_template = attr.ib()  # type: Optional[Template]
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
    """A mapping provider which renders user attributes with Jinja templates.

    This is the default mapping provider.
    """

    def __init__(self, config: JinjaOidcMappingConfig):
        """
        Args:
            config: The config object built by ``parse_config()``
        """
        self._config = config

    @staticmethod
    def parse_config(config: dict) -> JinjaOidcMappingConfig:
        """Compile the templates found in the provider's config.

        Args:
            config: The ``config`` dict of the user mapping provider. It must
                contain ``localpart_template``; ``subject_claim`` (which
                defaults to ``sub``) and ``display_name_template`` are
                optional.

        Raises:
            ConfigError: if ``localpart_template`` is missing or if a template
                could not be compiled.

        Returns:
            The parsed config.
        """
        subject_claim = config.get("subject_claim", "sub")

        if "localpart_template" not in config:
            raise ConfigError(
                "missing key: oidc_config.user_mapping_provider.config.localpart_template"
            )

        try:
            localpart_template = env.from_string(config["localpart_template"])
        except Exception as e:
            raise ConfigError(
                "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
                % (e,)
            )

        # The display name template is optional.
        display_name_template = None  # type: Optional[Template]
        display_name_source = config.get("display_name_template")
        if "display_name_template" in config:
            try:
                display_name_template = env.from_string(display_name_source)
            except Exception as e:
                raise ConfigError(
                    "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
                    % (e,)
                )

        return JinjaOidcMappingConfig(
            subject_claim=subject_claim,
            localpart_template=localpart_template,
            display_name_template=display_name_template,
        )

    def get_remote_user_id(self, userinfo: UserInfo) -> str:
        """Return the value of the configured subject claim."""
        claim = self._config.subject_claim
        return userinfo[claim]

    async def map_user_attributes(
        self, userinfo: UserInfo, token: Token
    ) -> UserAttribute:
        """Render the localpart and display name templates for a user.

        Both templates get the UserInfo object as ``user``. The rendered
        values are stripped of surrounding whitespace, and an empty display
        name is turned into ``None``.

        Args:
            userinfo: An object representing the user given by the OIDC provider
            token: A dict with the tokens returned by the provider (unused)

        Returns:
            A dict containing the ``localpart`` and the ``display_name``
        """
        localpart = self._config.localpart_template.render(user=userinfo).strip()

        display_name = None  # type: Optional[str]
        template = self._config.display_name_template
        if template is not None:
            # An empty rendering means "no display name".
            display_name = template.render(user=userinfo).strip() or None

        return UserAttribute(localpart=localpart, display_name=display_name)