)
+3. Add this Authorized redirect URI: `[synapse base url]/_synapse/oidc/callback`
+
+```yaml
+oidc_config:
+ enabled: true
+ issuer: "https://accounts.google.com/"
+ discover: true
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ scopes:
+ - openid
+ - profile
+ user_mapping_provider:
+ config:
+ localpart_template: '{{ user.given_name|lower }}'
+ display_name_template: '{{ user.name }}'
+```
+
+### Twitch
+
+1. Setup a developer account on [Twitch](https://dev.twitch.tv/)
+2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/)
+3. Add this OAuth Redirect URL: `[synapse base url]/_synapse/oidc/callback`
+
+```yaml
+oidc_config:
+ enabled: true
+ issuer: "https://id.twitch.tv/oauth2/"
+ discover: true
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ client_auth_method: "client_secret_post"
+ scopes:
+ - openid
+ user_mapping_provider:
+ config:
+ localpart_template: '{{ user.preferred_username }}'
+ display_name_template: '{{ user.name }}'
+```
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 98ead7dc0..1e397f773 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1470,6 +1470,94 @@ saml2_config:
#template_dir: "res/templates"
+# Enable OpenID Connect for registration and login. Uses authlib.
+#
+oidc_config:
+ # enable OpenID Connect. Defaults to false.
+ #
+ #enabled: true
+
+ # use the OIDC discovery mechanism to discover endpoints. Defaults to true.
+ #
+ #discover: true
+
+ # the OIDC issuer. Used to validate tokens and discover the providers endpoints. Required.
+ #
+ #issuer: "https://accounts.example.com/"
+
+ # oauth2 client id to use. Required.
+ #
+ #client_id: "provided-by-your-issuer"
+
+ # oauth2 client secret to use. Required.
+ #
+ #client_secret: "provided-by-your-issuer"
+
+ # auth method to use when exchanging the token.
+ # Valid values are "client_secret_basic" (default), "client_secret_post" and "none".
+ #
+ #client_auth_method: "client_auth_basic"
+
+ # list of scopes to ask. This should include the "openid" scope. Defaults to ["openid"].
+ #
+ #scopes: ["openid"]
+
+ # the oauth2 authorization endpoint. Required if provider discovery is disabled.
+ #
+ #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+
+ # the oauth2 token endpoint. Required if provider discovery is disabled.
+ #
+ #token_endpoint: "https://accounts.example.com/oauth2/token"
+
+ # the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not asked.
+ #
+ #userinfo_endpoint: "https://accounts.example.com/userinfo"
+
+ # URI where to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used.
+ #
+ #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+
+ # skip metadata verification. Defaults to false.
+ # Use this if you are connecting to a provider that is not OpenID Connect compliant.
+ # Avoid this in production.
+ #
+ #skip_verification: false
+
+
+ # An external module can be provided here as a custom solution to mapping
+ # attributes returned from a OIDC provider onto a matrix user.
+ #
+ user_mapping_provider:
+ # The custom module's class. Uncomment to use a custom module.
+ # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
+ #
+ #module: mapping_provider.OidcMappingProvider
+
+ # Custom configuration values for the module. Below options are intended
+ # for the built-in provider, they should be changed if using a custom
+ # module. This section will be passed as a Python dictionary to the
+ # module's `parse_config` method.
+ #
+ # Below is the config of the default mapping provider, based on Jinja2
+ # templates. Those templates are used to render user attributes, where the
+ # userinfo object is available through the `user` variable.
+ #
+ config:
+ # name of the claim containing a unique identifier for the user.
+ # Defaults to `sub`, which OpenID Connect compliant providers should provide.
+ #
+ #subject_claim: "sub"
+
+ # Jinja2 template for the localpart of the MXID
+ #
+ localpart_template: "{{ user.preferred_username }}"
+
+ # Jinja2 template for the display name to set on first login. Optional.
+ #
+ #display_name_template: "{{ user.given_name }} {{ user.last_name }}"
+
+
# Enable CAS for registration and login.
#
@@ -1554,6 +1642,13 @@ sso:
#
# This template has no additional variables.
#
+ # * HTML page to display to users if something goes wrong during the
+ # OpenID Connect authentication process: 'sso_error.html'.
+ #
+ # When rendering, this template is given two variables:
+ # * error: the technical name of the error
+ # * error_description: a human-readable message for the error
+ #
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
diff --git a/mypy.ini b/mypy.ini
index 69be2f67a..3533797d6 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -75,3 +75,6 @@ ignore_missing_imports = True
[mypy-jwt.*]
ignore_missing_imports = True
+
+[mypy-authlib.*]
+ignore_missing_imports = True
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index cbd1ea475..bc8695d8d 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -192,6 +192,11 @@ class SynapseHomeServer(HomeServer):
}
)
+ if self.get_config().oidc_enabled:
+ from synapse.rest.oidc import OIDCResource
+
+ resources["/_synapse/oidc"] = OIDCResource(self)
+
if self.get_config().saml2_enabled:
from synapse.rest.saml2 import SAML2Resource
@@ -422,6 +427,13 @@ def setup(config_options):
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
+ # Load the OIDC provider metadatas, if OIDC is enabled.
+ if hs.config.oidc_enabled:
+ oidc = hs.get_oidc_handler()
+ # Loading the provider metadata also ensures the provider config is valid.
+ yield defer.ensureDeferred(oidc.load_metadata())
+ yield defer.ensureDeferred(oidc.load_jwks())
+
_base.start(hs, config.listeners)
hs.get_datastore().db.updates.start_doing_background_updates()
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 3053fc9d2..9e576060d 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -13,6 +13,7 @@ from synapse.config import (
key,
logger,
metrics,
+ oidc_config,
password,
password_auth_providers,
push,
@@ -59,6 +60,7 @@ class RootConfig:
saml2: saml2_config.SAML2Config
cas: cas.CasConfig
sso: sso.SSOConfig
+ oidc: oidc_config.OIDCConfig
jwt: jwt_config.JWTConfig
password: password.PasswordConfig
email: emailconfig.EmailConfig
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index be6c6afa7..996d3e6bf 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -27,6 +27,7 @@ from .jwt_config import JWTConfig
from .key import KeyConfig
from .logger import LoggingConfig
from .metrics import MetricsConfig
+from .oidc_config import OIDCConfig
from .password import PasswordConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig
@@ -66,6 +67,7 @@ class HomeServerConfig(RootConfig):
AppServiceConfig,
KeyConfig,
SAML2Config,
+ OIDCConfig,
CasConfig,
SSOConfig,
JWTConfig,
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
new file mode 100644
index 000000000..5af110745
--- /dev/null
+++ b/synapse/config/oidc_config.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.util.module_loader import load_module
+
+from ._base import Config, ConfigError
+
+DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
+
+
+class OIDCConfig(Config):
+ section = "oidc"
+
+ def read_config(self, config, **kwargs):
+ self.oidc_enabled = False
+
+ oidc_config = config.get("oidc_config")
+
+ if not oidc_config or not oidc_config.get("enabled", False):
+ return
+
+ try:
+ check_requirements("oidc")
+ except DependencyException as e:
+ raise ConfigError(e.message)
+
+ public_baseurl = self.public_baseurl
+ if public_baseurl is None:
+ raise ConfigError("oidc_config requires a public_baseurl to be set")
+ self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
+
+ self.oidc_enabled = True
+ self.oidc_discover = oidc_config.get("discover", True)
+ self.oidc_issuer = oidc_config["issuer"]
+ self.oidc_client_id = oidc_config["client_id"]
+ self.oidc_client_secret = oidc_config["client_secret"]
+ self.oidc_client_auth_method = oidc_config.get(
+ "client_auth_method", "client_secret_basic"
+ )
+ self.oidc_scopes = oidc_config.get("scopes", ["openid"])
+ self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
+ self.oidc_token_endpoint = oidc_config.get("token_endpoint")
+ self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
+ self.oidc_jwks_uri = oidc_config.get("jwks_uri")
+ self.oidc_subject_claim = oidc_config.get("subject_claim", "sub")
+ self.oidc_skip_verification = oidc_config.get("skip_verification", False)
+
+ ump_config = oidc_config.get("user_mapping_provider", {})
+ ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+ ump_config.setdefault("config", {})
+
+ (
+ self.oidc_user_mapping_provider_class,
+ self.oidc_user_mapping_provider_config,
+ ) = load_module(ump_config)
+
+ # Ensure loaded user mapping module has defined all necessary methods
+ required_methods = [
+ "get_remote_user_id",
+ "map_user_attributes",
+ ]
+ missing_methods = [
+ method
+ for method in required_methods
+ if not hasattr(self.oidc_user_mapping_provider_class, method)
+ ]
+ if missing_methods:
+ raise ConfigError(
+ "Class specified by oidc_config."
+ "user_mapping_provider.module is missing required "
+ "methods: %s" % (", ".join(missing_methods),)
+ )
+
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ return """\
+ # Enable OpenID Connect for registration and login. Uses authlib.
+ #
+ oidc_config:
+ # enable OpenID Connect. Defaults to false.
+ #
+ #enabled: true
+
+ # use the OIDC discovery mechanism to discover endpoints. Defaults to true.
+ #
+ #discover: true
+
+ # the OIDC issuer. Used to validate tokens and discover the providers endpoints. Required.
+ #
+ #issuer: "https://accounts.example.com/"
+
+ # oauth2 client id to use. Required.
+ #
+ #client_id: "provided-by-your-issuer"
+
+ # oauth2 client secret to use. Required.
+ #
+ #client_secret: "provided-by-your-issuer"
+
+ # auth method to use when exchanging the token.
+ # Valid values are "client_secret_basic" (default), "client_secret_post" and "none".
+ #
+ #client_auth_method: "client_auth_basic"
+
+ # list of scopes to ask. This should include the "openid" scope. Defaults to ["openid"].
+ #
+ #scopes: ["openid"]
+
+ # the oauth2 authorization endpoint. Required if provider discovery is disabled.
+ #
+ #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+
+ # the oauth2 token endpoint. Required if provider discovery is disabled.
+ #
+ #token_endpoint: "https://accounts.example.com/oauth2/token"
+
+ # the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not asked.
+ #
+ #userinfo_endpoint: "https://accounts.example.com/userinfo"
+
+ # URI where to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used.
+ #
+ #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+
+ # skip metadata verification. Defaults to false.
+ # Use this if you are connecting to a provider that is not OpenID Connect compliant.
+ # Avoid this in production.
+ #
+ #skip_verification: false
+
+
+ # An external module can be provided here as a custom solution to mapping
+ # attributes returned from a OIDC provider onto a matrix user.
+ #
+ user_mapping_provider:
+ # The custom module's class. Uncomment to use a custom module.
+ # Default is {mapping_provider!r}.
+ #
+ #module: mapping_provider.OidcMappingProvider
+
+ # Custom configuration values for the module. Below options are intended
+ # for the built-in provider, they should be changed if using a custom
+ # module. This section will be passed as a Python dictionary to the
+ # module's `parse_config` method.
+ #
+ # Below is the config of the default mapping provider, based on Jinja2
+ # templates. Those templates are used to render user attributes, where the
+ # userinfo object is available through the `user` variable.
+ #
+ config:
+ # name of the claim containing a unique identifier for the user.
+ # Defaults to `sub`, which OpenID Connect compliant providers should provide.
+ #
+ #subject_claim: "sub"
+
+ # Jinja2 template for the localpart of the MXID
+ #
+ localpart_template: "{{{{ user.preferred_username }}}}"
+
+ # Jinja2 template for the display name to set on first login. Optional.
+ #
+ #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
+ """.format(
+ mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
+ )
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index cac6bc013..aff642f01 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -36,17 +36,13 @@ class SSOConfig(Config):
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
- self.sso_redirect_confirm_template_dir = template_dir
+ self.sso_template_dir = template_dir
self.sso_account_deactivated_template = self.read_file(
- os.path.join(
- self.sso_redirect_confirm_template_dir, "sso_account_deactivated.html"
- ),
+ os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
"sso_account_deactivated_template",
)
self.sso_auth_success_template = self.read_file(
- os.path.join(
- self.sso_redirect_confirm_template_dir, "sso_auth_success.html"
- ),
+ os.path.join(self.sso_template_dir, "sso_auth_success.html"),
"sso_auth_success_template",
)
@@ -137,6 +133,13 @@ class SSOConfig(Config):
#
# This template has no additional variables.
#
+ # * HTML page to display to users if something goes wrong during the
+ # OpenID Connect authentication process: 'sso_error.html'.
+ #
+ # When rendering, this template is given two variables:
+ # * error: the technical name of the error
+ # * error_description: a human-readable message for the error
+ #
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 7613e5b6a..f8d2331bf 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -126,13 +126,13 @@ class AuthHandler(BaseHandler):
# It notifies the user they are about to give access to their matrix account
# to the client.
self._sso_redirect_confirm_template = load_jinja2_templates(
- hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
+ hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
)[0]
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
self._sso_auth_confirm_template = load_jinja2_templates(
- hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"],
+ hs.config.sso_template_dir, ["sso_auth_confirm.html"],
)[0]
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
new file mode 100644
index 000000000..178f26343
--- /dev/null
+++ b/synapse/handlers/oidc_handler.py
@@ -0,0 +1,998 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+from typing import Dict, Generic, List, Optional, Tuple, TypeVar
+from urllib.parse import urlencode
+
+import attr
+import pymacaroons
+from authlib.common.security import generate_token
+from authlib.jose import JsonWebToken
+from authlib.oauth2.auth import ClientAuth
+from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
+from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
+from jinja2 import Environment, Template
+from pymacaroons.exceptions import (
+ MacaroonDeserializationException,
+ MacaroonInvalidSignatureException,
+)
+from typing_extensions import TypedDict
+
+from twisted.web.client import readBody
+
+from synapse.config import ConfigError
+from synapse.http.server import finish_request
+from synapse.http.site import SynapseRequest
+from synapse.push.mailer import load_jinja2_templates
+from synapse.server import HomeServer
+from synapse.types import UserID, map_username_to_mxid_localpart
+
+logger = logging.getLogger(__name__)
+
+SESSION_COOKIE_NAME = b"oidc_session"
+
+#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
+#: OpenID.Core sec 3.1.3.3.
+Token = TypedDict(
+ "Token",
+ {
+ "access_token": str,
+ "token_type": str,
+ "id_token": Optional[str],
+ "refresh_token": Optional[str],
+ "expires_in": int,
+ "scope": Optional[str],
+ },
+)
+
+#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
+#: there is no real point of doing this in our case.
+JWK = Dict[str, str]
+
+#: A JWK Set, as per RFC7517 sec 5.
+JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+
+
+class OidcError(Exception):
+ """Used to catch errors when calling the token_endpoint
+ """
+
+ def __init__(self, error, error_description=None):
+ self.error = error
+ self.error_description = error_description
+
+ def __str__(self):
+ if self.error_description:
+ return "{}: {}".format(self.error, self.error_description)
+ return self.error
+
+
+class MappingException(Exception):
+ """Used to catch errors when mapping the UserInfo object
+ """
+
+
+class OidcHandler:
+ """Handles requests related to the OpenID Connect login flow.
+ """
+
+ def __init__(self, hs: HomeServer):
+ self._callback_url = hs.config.oidc_callback_url # type: str
+ self._scopes = hs.config.oidc_scopes # type: List[str]
+ self._client_auth = ClientAuth(
+ hs.config.oidc_client_id,
+ hs.config.oidc_client_secret,
+ hs.config.oidc_client_auth_method,
+ ) # type: ClientAuth
+ self._client_auth_method = hs.config.oidc_client_auth_method # type: str
+ self._subject_claim = hs.config.oidc_subject_claim
+ self._provider_metadata = OpenIDProviderMetadata(
+ issuer=hs.config.oidc_issuer,
+ authorization_endpoint=hs.config.oidc_authorization_endpoint,
+ token_endpoint=hs.config.oidc_token_endpoint,
+ userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
+ jwks_uri=hs.config.oidc_jwks_uri,
+ ) # type: OpenIDProviderMetadata
+ self._provider_needs_discovery = hs.config.oidc_discover # type: bool
+ self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
+ hs.config.oidc_user_mapping_provider_config
+ ) # type: OidcMappingProvider
+ self._skip_verification = hs.config.oidc_skip_verification # type: bool
+
+ self._http_client = hs.get_proxied_http_client()
+ self._auth_handler = hs.get_auth_handler()
+ self._registration_handler = hs.get_registration_handler()
+ self._datastore = hs.get_datastore()
+ self._clock = hs.get_clock()
+ self._hostname = hs.hostname # type: str
+ self._server_name = hs.config.server_name # type: str
+ self._macaroon_secret_key = hs.config.macaroon_secret_key
+ self._error_template = load_jinja2_templates(
+ hs.config.sso_template_dir, ["sso_error.html"]
+ )[0]
+
+ # identifier for the external_ids table
+ self._auth_provider_id = "oidc"
+
+ def _render_error(
+ self, request, error: str, error_description: Optional[str] = None
+ ) -> None:
+ """Renders the error template and respond with it.
+
+ This is used to show errors to the user. The template of this page can
+ be found under ``synapse/res/templates/sso_error.html``.
+
+ Args:
+ request: The incoming request from the browser.
+ We'll respond with an HTML page describing the error.
+ error: A technical identifier for this error. Those include
+ well-known OAuth2/OIDC error types like invalid_request or
+ access_denied.
+ error_description: A human-readable description of the error.
+ """
+ html_bytes = self._error_template.render(
+ error=error, error_description=error_description
+ ).encode("utf-8")
+
+ request.setResponseCode(400)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%i" % len(html_bytes))
+ request.write(html_bytes)
+ finish_request(request)
+
+ def _validate_metadata(self):
+ """Verifies the provider metadata.
+
+ This checks the validity of the currently loaded provider. Not
+ everything is checked, only:
+
+ - ``issuer``
+ - ``authorization_endpoint``
+ - ``token_endpoint``
+ - ``response_types_supported`` (checks if "code" is in it)
+ - ``jwks_uri``
+
+ Raises:
+ ValueError: if something in the provider is not valid
+ """
+ # Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin)
+ if self._skip_verification is True:
+ return
+
+ m = self._provider_metadata
+ m.validate_issuer()
+ m.validate_authorization_endpoint()
+ m.validate_token_endpoint()
+
+ if m.get("token_endpoint_auth_methods_supported") is not None:
+ m.validate_token_endpoint_auth_methods_supported()
+ if (
+ self._client_auth_method
+ not in m["token_endpoint_auth_methods_supported"]
+ ):
+ raise ValueError(
+ '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format(
+ auth_method=self._client_auth_method,
+ supported=m["token_endpoint_auth_methods_supported"],
+ )
+ )
+
+ if m.get("response_types_supported") is not None:
+ m.validate_response_types_supported()
+
+ if "code" not in m["response_types_supported"]:
+ raise ValueError(
+ '"code" not in "response_types_supported" (%r)'
+ % (m["response_types_supported"],)
+ )
+
+ # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
+ if self._uses_userinfo:
+ if m.get("userinfo_endpoint") is None:
+ raise ValueError(
+ 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
+ )
+ else:
+ # If we're not using userinfo, we need a valid jwks to validate the ID token
+ if m.get("jwks") is None:
+ if m.get("jwks_uri") is not None:
+ m.validate_jwks_uri()
+ else:
+ raise ValueError('"jwks_uri" must be set')
+
+ @property
+ def _uses_userinfo(self) -> bool:
+ """Returns True if the ``userinfo_endpoint`` should be used.
+
+ This is based on the requested scopes: if the scopes include
+ ``openid``, the provider should give use an ID token containing the
+ user informations. If not, we should fetch them using the
+ ``access_token`` with the ``userinfo_endpoint``.
+ """
+
+ # Maybe that should be user-configurable and not inferred?
+ return "openid" not in self._scopes
+
+ async def load_metadata(self) -> OpenIDProviderMetadata:
+ """Load and validate the provider metadata.
+
+ The values metadatas are discovered if ``oidc_config.discovery`` is
+ ``True`` and then cached.
+
+ Raises:
+ ValueError: if something in the provider is not valid
+
+ Returns:
+ The provider's metadata.
+ """
+ # If we are using the OpenID Discovery documents, it needs to be loaded once
+ # FIXME: should there be a lock here?
+ if self._provider_needs_discovery:
+ url = get_well_known_url(self._provider_metadata["issuer"], external=True)
+ metadata_response = await self._http_client.get_json(url)
+ # TODO: maybe update the other way around to let user override some values?
+ self._provider_metadata.update(metadata_response)
+ self._provider_needs_discovery = False
+
+ self._validate_metadata()
+
+ return self._provider_metadata
+
+ async def load_jwks(self, force: bool = False) -> JWKS:
+ """Load the JSON Web Key Set used to sign ID tokens.
+
+ If we're not using the ``userinfo_endpoint``, user infos are extracted
+ from the ID token, which is a JWT signed by keys given by the provider.
+ The keys are then cached.
+
+ Args:
+ force: Force reloading the keys.
+
+ Returns:
+ The key set
+
+ Looks like this::
+
+ {
+ 'keys': [
+ {
+ 'kid': 'abcdef',
+ 'kty': 'RSA',
+ 'alg': 'RS256',
+ 'use': 'sig',
+ 'e': 'XXXX',
+ 'n': 'XXXX',
+ }
+ ]
+ }
+ """
+ if self._uses_userinfo:
+ # We're not using jwt signing, return an empty jwk set
+ return {"keys": []}
+
+ # First check if the JWKS are loaded in the provider metadata.
+ # It can happen either if the provider gives its JWKS in the discovery
+ # document directly or if it was already loaded once.
+ metadata = await self.load_metadata()
+ jwk_set = metadata.get("jwks")
+ if jwk_set is not None and not force:
+ return jwk_set
+
+ # Loading the JWKS using the `jwks_uri` metadata
+ uri = metadata.get("jwks_uri")
+ if not uri:
+ raise RuntimeError('Missing "jwks_uri" in metadata')
+
+ jwk_set = await self._http_client.get_json(uri)
+
+ # Caching the JWKS in the provider's metadata
+ self._provider_metadata["jwks"] = jwk_set
+ return jwk_set
+
+ async def _exchange_code(self, code: str) -> Token:
+ """Exchange an authorization code for a token.
+
+ This calls the ``token_endpoint`` with the authorization code we
+ received in the callback to exchange it for a token. The call uses the
+ ``ClientAuth`` to authenticate with the client with its ID and secret.
+
+ Args:
+ code: The autorization code we got from the callback.
+
+ Returns:
+ A dict containing various tokens.
+
+ May look like this::
+
+ {
+ 'token_type': 'bearer',
+ 'access_token': 'abcdef',
+ 'expires_in': 3599,
+ 'id_token': 'ghijkl',
+ 'refresh_token': 'mnopqr',
+ }
+
+ Raises:
+ OidcError: when the ``token_endpoint`` returned an error.
+ """
+ metadata = await self.load_metadata()
+ token_endpoint = metadata.get("token_endpoint")
+ headers = {
+ "Content-Type": "application/x-www-form-urlencoded",
+ "User-Agent": self._http_client.user_agent,
+ "Accept": "application/json",
+ }
+
+ args = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": self._callback_url,
+ }
+ body = urlencode(args, True)
+
+ # Fill the body/headers with credentials
+ uri, headers, body = self._client_auth.prepare(
+ method="POST", uri=token_endpoint, headers=headers, body=body
+ )
+ headers = {k: [v] for (k, v) in headers.items()}
+
+ # Do the actual request
+ # We're not using the SimpleHttpClient util methods as we don't want to
+ # check the HTTP status code and we do the body encoding ourself.
+ response = await self._http_client.request(
+ method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
+ )
+
+ # This is used in multiple error messages below
+ status = "{code} {phrase}".format(
+ code=response.code, phrase=response.phrase.decode("utf-8")
+ )
+
+ resp_body = await readBody(response)
+
+ if response.code >= 500:
+ # In case of a server error, we should first try to decode the body
+ # and check for an error field. If not, we respond with a generic
+ # error message.
+ try:
+ resp = json.loads(resp_body.decode("utf-8"))
+ error = resp["error"]
+ description = resp.get("error_description", error)
+ except (ValueError, KeyError):
+ # Catch ValueError for the JSON decoding and KeyError for the "error" field
+ error = "server_error"
+ description = (
+ (
+ 'Authorization server responded with a "{status}" error '
+ "while exchanging the authorization code."
+ ).format(status=status),
+ )
+
+ raise OidcError(error, description)
+
+ # Since it is a not a 5xx code, body should be a valid JSON. It will
+ # raise if not.
+ resp = json.loads(resp_body.decode("utf-8"))
+
+ if "error" in resp:
+ error = resp["error"]
+ # In case the authorization server responded with an error field,
+ # it should be a 4xx code. If not, warn about it but don't do
+ # anything special and report the original error message.
+ if response.code < 400:
+ logger.debug(
+ "Invalid response from the authorization server: "
+ 'responded with a "{status}" '
+ "but body has an error field: {error!r}".format(
+ status=status, error=resp["error"]
+ )
+ )
+
+ description = resp.get("error_description", error)
+ raise OidcError(error, description)
+
+ # Now, this should not be an error. According to RFC6749 sec 5.1, it
+ # should be a 200 code. We're a bit more flexible than that, and will
+ # only throw on a 4xx code.
+ if response.code >= 400:
+ description = (
+ 'Authorization server responded with a "{status}" error '
+ 'but did not include an "error" field in its response.'.format(
+ status=status
+ )
+ )
+ logger.warning(description)
+ # Body was still valid JSON. Might be useful to log it for debugging.
+ logger.warning("Code exchange response: {resp!r}".format(resp=resp))
+ raise OidcError("server_error", description)
+
+ return resp
+
+ async def _fetch_userinfo(self, token: Token) -> UserInfo:
+ """Fetch user informations from the ``userinfo_endpoint``.
+
+ Args:
+ token: the token given by the ``token_endpoint``.
+ Must include an ``access_token`` field.
+
+ Returns:
+ UserInfo: an object representing the user.
+ """
+ metadata = await self.load_metadata()
+
+ resp = await self._http_client.get_json(
+ metadata["userinfo_endpoint"],
+ headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
+ )
+
+ return UserInfo(resp)
+
+ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+ """Return an instance of UserInfo from token's ``id_token``.
+
+ Args:
+ token: the token given by the ``token_endpoint``.
+ Must include an ``id_token`` field.
+ nonce: the nonce value originally sent in the initial authorization
+ request. This value should match the one inside the token.
+
+ Returns:
+ An object representing the user.
+ """
+ metadata = await self.load_metadata()
+ claims_params = {
+ "nonce": nonce,
+ "client_id": self._client_auth.client_id,
+ }
+ if "access_token" in token:
+ # If we got an `access_token`, there should be an `at_hash` claim
+ # in the `id_token` that we can check against.
+ claims_params["access_token"] = token["access_token"]
+ claims_cls = CodeIDToken
+ else:
+ claims_cls = ImplicitIDToken
+
+ alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+ jwt = JsonWebToken(alg_values)
+
+ claim_options = {"iss": {"values": [metadata["issuer"]]}}
+
+ # Try to decode the keys in cache first, then retry by forcing the keys
+ # to be reloaded
+ jwk_set = await self.load_jwks()
+ try:
+ claims = jwt.decode(
+ token["id_token"],
+ key=jwk_set,
+ claims_cls=claims_cls,
+ claims_options=claim_options,
+ claims_params=claims_params,
+ )
+ except ValueError:
+ jwk_set = await self.load_jwks(force=True) # try reloading the jwks
+ claims = jwt.decode(
+ token["id_token"],
+ key=jwk_set,
+ claims_cls=claims_cls,
+ claims_options=claim_options,
+ claims_params=claims_params,
+ )
+
+ claims.validate(leeway=120) # allows 2 min of clock skew
+ return UserInfo(claims)
+
+ async def handle_redirect_request(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> None:
+ """Handle an incoming request to /login/sso/redirect
+
+ It redirects the browser to the authorization endpoint with a few
+ parameters:
+
+ - ``client_id``: the client ID set in ``oidc_config.client_id``
+ - ``response_type``: ``code``
+ - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
+ - ``scope``: the list of scopes set in ``oidc_config.scopes``
+ - ``state``: a random string
+ - ``nonce``: a random string
+
+ In addition to redirecting the client, we are setting a cookie with
+ a signed macaroon token containing the state, the nonce and the
+ client_redirect_url params. Those are then checked when the client
+ comes back from the provider.
+
+
+ Args:
+ request: the incoming request from the browser.
+ We'll respond to it with a redirect and a cookie.
+ client_redirect_url: the URL that we should redirect the client to
+ when everything is done
+ """
+
+ state = generate_token()
+ nonce = generate_token()
+
+ cookie = self._generate_oidc_session_token(
+ state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
+ )
+ request.addCookie(
+ SESSION_COOKIE_NAME,
+ cookie,
+ path="/_synapse/oidc",
+ max_age="3600",
+ httpOnly=True,
+ sameSite="lax",
+ )
+
+ metadata = await self.load_metadata()
+ authorization_endpoint = metadata.get("authorization_endpoint")
+ uri = prepare_grant_uri(
+ authorization_endpoint,
+ client_id=self._client_auth.client_id,
+ response_type="code",
+ redirect_uri=self._callback_url,
+ scope=self._scopes,
+ state=state,
+ nonce=nonce,
+ )
+ request.redirect(uri)
+ finish_request(request)
+
+ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+ """Handle an incoming request to /_synapse/oidc/callback
+
+ Since we might want to display OIDC-related errors in a user-friendly
+ way, we don't raise SynapseError from here. Instead, we call
+ ``self._render_error`` which displays an HTML page for the error.
+
+ Most of the OpenID Connect logic happens here:
+
+ - first, we check if there was any error returned by the provider and
+ display it
+ - then we fetch the session cookie, decode and verify it
+ - the ``state`` query parameter should match with the one stored in the
+ session cookie
+ - once we known this session is legit, exchange the code with the
+ provider using the ``token_endpoint`` (see ``_exchange_code``)
+ - once we have the token, use it to either extract the UserInfo from
+ the ``id_token`` (``_parse_id_token``), or use the ``access_token``
+ to fetch UserInfo from the ``userinfo_endpoint``
+ (``_fetch_userinfo``)
+ - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and
+ finish the login
+
+ Args:
+ request: the incoming request from the browser.
+ """
+
+ # The provider might redirect with an error.
+ # In that case, just display it as-is.
+ if b"error" in request.args:
+ error = request.args[b"error"][0].decode()
+ description = request.args.get(b"error_description", [b""])[0].decode()
+
+ # Most of the errors returned by the provider could be due by
+ # either the provider misbehaving or Synapse being misconfigured.
+ # The only exception of that is "access_denied", where the user
+ # probably cancelled the login flow. In other cases, log those errors.
+ if error != "access_denied":
+ logger.error("Error from the OIDC provider: %s %s", error, description)
+
+ self._render_error(request, error, description)
+ return
+
+ # Fetch the session cookie
+ session = request.getCookie(SESSION_COOKIE_NAME)
+ if session is None:
+ logger.info("No session cookie found")
+ self._render_error(request, "missing_session", "No session cookie found")
+ return
+
+ # Remove the cookie. There is a good chance that if the callback failed
+ # once, it will fail next time and the code will already be exchanged.
+ # Removing it early avoids spamming the provider with token requests.
+ request.addCookie(
+ SESSION_COOKIE_NAME,
+ b"",
+ path="/_synapse/oidc",
+ expires="Thu, Jan 01 1970 00:00:00 UTC",
+ httpOnly=True,
+ sameSite="lax",
+ )
+
+ # Check for the state query parameter
+ if b"state" not in request.args:
+ logger.info("State parameter is missing")
+ self._render_error(request, "invalid_request", "State parameter is missing")
+ return
+
+ state = request.args[b"state"][0].decode()
+
+ # Deserialize the session token and verify it.
+ try:
+ nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
+ except MacaroonDeserializationException as e:
+ logger.exception("Invalid session")
+ self._render_error(request, "invalid_session", str(e))
+ return
+ except MacaroonInvalidSignatureException as e:
+ logger.exception("Could not verify session")
+ self._render_error(request, "mismatching_session", str(e))
+ return
+
+ # Exchange the code with the provider
+ if b"code" not in request.args:
+ logger.info("Code parameter is missing")
+ self._render_error(request, "invalid_request", "Code parameter is missing")
+ return
+
+ logger.info("Exchanging code")
+ code = request.args[b"code"][0].decode()
+ try:
+ token = await self._exchange_code(code)
+ except OidcError as e:
+ logger.exception("Could not exchange code")
+ self._render_error(request, e.error, e.error_description)
+ return
+
+ # Now that we have a token, get the userinfo, either by decoding the
+ # `id_token` or by fetching the `userinfo_endpoint`.
+ if self._uses_userinfo:
+ logger.info("Fetching userinfo")
+ try:
+ userinfo = await self._fetch_userinfo(token)
+ except Exception as e:
+ logger.exception("Could not fetch userinfo")
+ self._render_error(request, "fetch_error", str(e))
+ return
+ else:
+ logger.info("Extracting userinfo from id_token")
+ try:
+ userinfo = await self._parse_id_token(token, nonce=nonce)
+ except Exception as e:
+ logger.exception("Invalid id_token")
+ self._render_error(request, "invalid_token", str(e))
+ return
+
+ # Call the mapper to register/login the user
+ try:
+ user_id = await self._map_userinfo_to_user(userinfo, token)
+ except MappingException as e:
+ logger.exception("Could not map user")
+ self._render_error(request, "mapping_error", str(e))
+ return
+
+ # and finally complete the login
+ await self._auth_handler.complete_sso_login(
+ user_id, request, client_redirect_url
+ )
+
+ def _generate_oidc_session_token(
+ self,
+ state: str,
+ nonce: str,
+ client_redirect_url: str,
+ duration_in_ms: int = (60 * 60 * 1000),
+ ) -> str:
+ """Generates a signed token storing data about an OIDC session.
+
+ When Synapse initiates an authorization flow, it creates a random state
+ and a random nonce. Those parameters are given to the provider and
+ should be verified when the client comes back from the provider.
+ It is also used to store the client_redirect_url, which is used to
+ complete the SSO login flow.
+
+ Args:
+ state: The ``state`` parameter passed to the OIDC provider.
+ nonce: The ``nonce`` parameter passed to the OIDC provider.
+ client_redirect_url: The URL the client gave when it initiated the
+ flow.
+ duration_in_ms: An optional duration for the token in milliseconds.
+ Defaults to an hour.
+
+ Returns:
+ A signed macaroon token with the session informations.
+ """
+ macaroon = pymacaroons.Macaroon(
+ location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+ )
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = session")
+ macaroon.add_first_party_caveat("state = %s" % (state,))
+ macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
+ macaroon.add_first_party_caveat(
+ "client_redirect_url = %s" % (client_redirect_url,)
+ )
+ now = self._clock.time_msec()
+ expiry = now + duration_in_ms
+ macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ return macaroon.serialize()
+
+ def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
+ """Verifies and extract an OIDC session token.
+
+ This verifies that a given session token was issued by this homeserver
+ and extract the nonce and client_redirect_url caveats.
+
+ Args:
+ session: The session token to verify
+ state: The state the OIDC provider gave back
+
+ Returns:
+ The nonce and the client_redirect_url for this session
+ """
+ macaroon = pymacaroons.Macaroon.deserialize(session)
+
+ v = pymacaroons.Verifier()
+ v.satisfy_exact("gen = 1")
+ v.satisfy_exact("type = session")
+ v.satisfy_exact("state = %s" % (state,))
+ v.satisfy_general(lambda c: c.startswith("nonce = "))
+ v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+ v.satisfy_general(self._verify_expiry)
+
+ v.verify(macaroon, self._macaroon_secret_key)
+
+ # Extract the `nonce` and `client_redirect_url` from the token
+ nonce = self._get_value_from_macaroon(macaroon, "nonce")
+ client_redirect_url = self._get_value_from_macaroon(
+ macaroon, "client_redirect_url"
+ )
+
+ return nonce, client_redirect_url
+
+ def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
+ """Extracts a caveat value from a macaroon token.
+
+ Args:
+ macaroon: the token
+ key: the key of the caveat to extract
+
+ Returns:
+ The extracted value
+
+ Raises:
+ Exception: if the caveat was not in the macaroon
+ """
+ prefix = key + " = "
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith(prefix):
+ return caveat.caveat_id[len(prefix) :]
+ raise Exception("No %s caveat in macaroon" % (key,))
+
+ def _verify_expiry(self, caveat: str) -> bool:
+ prefix = "time < "
+ if not caveat.startswith(prefix):
+ return False
+ expiry = int(caveat[len(prefix) :])
+ now = self._clock.time_msec()
+ return now < expiry
+
+ async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
+ """Maps a UserInfo object to a mxid.
+
+ UserInfo should have a claim that uniquely identifies users. This claim
+ is usually `sub`, but can be configured with `oidc_config.subject_claim`.
+ It is then used as an `external_id`.
+
+ If we don't find the user that way, we should register the user,
+ mapping the localpart and the display name from the UserInfo.
+
+ If a user already exists with the mxid we've mapped, raise an exception.
+
+ Args:
+ userinfo: an object representing the user
+ token: a dict with the tokens obtained from the provider
+
+ Raises:
+ MappingException: if there was an error while mapping some properties
+
+ Returns:
+ The mxid of the user
+ """
+ try:
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ except Exception as e:
+ raise MappingException(
+ "Failed to extract subject from OIDC response: %s" % (e,)
+ )
+
+ logger.info(
+ "Looking for existing mapping for user %s:%s",
+ self._auth_provider_id,
+ remote_user_id,
+ )
+
+ registered_user_id = await self._datastore.get_user_by_external_id(
+ self._auth_provider_id, remote_user_id,
+ )
+
+ if registered_user_id is not None:
+ logger.info("Found existing mapping %s", registered_user_id)
+ return registered_user_id
+
+ try:
+ attributes = await self._user_mapping_provider.map_user_attributes(
+ userinfo, token
+ )
+ except Exception as e:
+ raise MappingException(
+ "Could not extract user attributes from OIDC response: " + str(e)
+ )
+
+ logger.debug(
+ "Retrieved user attributes from user mapping provider: %r", attributes
+ )
+
+ if not attributes["localpart"]:
+ raise MappingException("localpart is empty")
+
+ localpart = map_username_to_mxid_localpart(attributes["localpart"])
+
+ user_id = UserID(localpart, self._hostname)
+ if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
+ # This mxid is taken
+ raise MappingException(
+ "mxid '{}' is already taken".format(user_id.to_string())
+ )
+
+ # It's the first time this user is logging in and the mapped mxid was
+ # not taken, register the user
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart, default_display_name=attributes["display_name"],
+ )
+
+ await self._datastore.record_user_external_id(
+ self._auth_provider_id, remote_user_id, registered_user_id,
+ )
+ return registered_user_id
+
+
+UserAttribute = TypedDict(
+ "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+)
+C = TypeVar("C")
+
+
+class OidcMappingProvider(Generic[C]):
+ """A mapping provider maps a UserInfo object to user attributes.
+
+ It should provide the API described by this class.
+ """
+
+ def __init__(self, config: C):
+ """
+ Args:
+ config: A custom config object from this module, parsed by ``parse_config()``
+ """
+
+ @staticmethod
+ def parse_config(config: dict) -> C:
+ """Parse the dict provided by the homeserver's config
+
+ Args:
+ config: A dictionary containing configuration options for this provider
+
+ Returns:
+ A custom config object for this module
+ """
+ raise NotImplementedError()
+
+ def get_remote_user_id(self, userinfo: UserInfo) -> str:
+ """Get a unique user ID for this user.
+
+ Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object.
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+
+ Returns:
+ A unique user ID
+ """
+ raise NotImplementedError()
+
+ async def map_user_attributes(
+ self, userinfo: UserInfo, token: Token
+ ) -> UserAttribute:
+ """Map a ``UserInfo`` objects into user attributes.
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ token: A dict with the tokens returned by the provider
+
+ Returns:
+ A dict containing the ``localpart`` and (optionally) the ``display_name``
+ """
+ raise NotImplementedError()
+
+
+# Used to clear out "None" values in templates
+def jinja_finalize(thing):
+ return thing if thing is not None else ""
+
+
+env = Environment(finalize=jinja_finalize)
+
+
+@attr.s
+class JinjaOidcMappingConfig:
+ subject_claim = attr.ib() # type: str
+ localpart_template = attr.ib() # type: Template
+ display_name_template = attr.ib() # type: Optional[Template]
+
+
+class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
+ """An implementation of a mapping provider based on Jinja templates.
+
+ This is the default mapping provider.
+ """
+
+ def __init__(self, config: JinjaOidcMappingConfig):
+ self._config = config
+
+ @staticmethod
+ def parse_config(config: dict) -> JinjaOidcMappingConfig:
+ subject_claim = config.get("subject_claim", "sub")
+
+ if "localpart_template" not in config:
+ raise ConfigError(
+ "missing key: oidc_config.user_mapping_provider.config.localpart_template"
+ )
+
+ try:
+ localpart_template = env.from_string(config["localpart_template"])
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
+ % (e,)
+ )
+
+ display_name_template = None # type: Optional[Template]
+ if "display_name_template" in config:
+ try:
+ display_name_template = env.from_string(config["display_name_template"])
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
+ % (e,)
+ )
+
+ return JinjaOidcMappingConfig(
+ subject_claim=subject_claim,
+ localpart_template=localpart_template,
+ display_name_template=display_name_template,
+ )
+
+ def get_remote_user_id(self, userinfo: UserInfo) -> str:
+ return userinfo[self._config.subject_claim]
+
+ async def map_user_attributes(
+ self, userinfo: UserInfo, token: Token
+ ) -> UserAttribute:
+ localpart = self._config.localpart_template.render(user=userinfo).strip()
+
+ display_name = None # type: Optional[str]
+ if self._config.display_name_template is not None:
+ display_name = self._config.display_name_template.render(
+ user=userinfo
+ ).strip()
+
+ if display_name == "":
+ display_name = None
+
+ return UserAttribute(localpart=localpart, display_name=display_name)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 379754582..58eb47c69 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -359,6 +359,7 @@ class SimpleHttpClient(object):
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
+ b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
@@ -399,6 +400,7 @@ class SimpleHttpClient(object):
actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
+ b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
@@ -434,6 +436,10 @@ class SimpleHttpClient(object):
ValueError: if the response was not JSON
"""
+ actual_headers = {b"Accept": [b"application/json"]}
+ if headers:
+ actual_headers.update(headers)
+
body = yield self.get_raw(uri, args, headers=headers)
return json.loads(body)
@@ -467,6 +473,7 @@ class SimpleHttpClient(object):
actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
+ b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 39c99a280..8b4312e5a 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -92,6 +92,7 @@ CONDITIONAL_REQUIREMENTS = {
'eliot<1.8.0;python_version<"3.5.3"',
],
"saml2": ["pysaml2>=4.5.0"],
+ "oidc": ["authlib>=0.14.0"],
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
"test": ["mock>=2.0", "parameterized"],
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
new file mode 100644
index 000000000..43a211386
--- /dev/null
+++ b/synapse/res/templates/sso_error.html
@@ -0,0 +1,18 @@
+
+
+
+
+ SSO error
+
+
+ Oops! Something went wrong during authentication.
+
+ Try logging in again from your Matrix client and if the problem persists
+ please contact the server's administrator.
+
+ Error: {{ error }}
+ {% if error_description %}
+ {{ error_description }}
+ {% endif %}
+
+
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 4de2f97d0..de7eca21f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -83,6 +83,7 @@ class LoginRestServlet(RestServlet):
self.jwt_algorithm = hs.config.jwt_algorithm
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
+ self.oidc_enabled = hs.config.oidc_enabled
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -96,9 +97,7 @@ class LoginRestServlet(RestServlet):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
- if self.saml2_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+
if self.cas_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
@@ -114,6 +113,11 @@ class LoginRestServlet(RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+ elif self.saml2_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
+ flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+ elif self.oidc_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
flows.extend(
({"type": t} for t in self.auth_handler.get_supported_login_types())
@@ -465,6 +469,22 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
return self._saml_handler.handle_redirect_request(client_redirect_url)
+class OIDCRedirectServlet(RestServlet):
+ """Implementation for /login/sso/redirect for the OIDC login flow."""
+
+ PATTERNS = client_patterns("/login/sso/redirect", v1=True)
+
+ def __init__(self, hs):
+ self._oidc_handler = hs.get_oidc_handler()
+
+ async def on_GET(self, request):
+ args = request.args
+ if b"redirectUrl" not in args:
+ return 400, "Redirect URL not specified for SSO auth"
+ client_redirect_url = args[b"redirectUrl"][0]
+ await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
+
+
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.cas_enabled:
@@ -472,3 +492,5 @@ def register_servlets(hs, http_server):
CasTicketServlet(hs).register(http_server)
elif hs.config.saml2_enabled:
SAMLRedirectServlet(hs).register(http_server)
+ elif hs.config.oidc_enabled:
+ OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/oidc/__init__.py
new file mode 100644
index 000000000..d958dd65b
--- /dev/null
+++ b/synapse/rest/oidc/__init__.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.web.resource import Resource
+
+from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCResource(Resource):
+ def __init__(self, hs):
+ Resource.__init__(self)
+ self.putChild(b"callback", OIDCCallbackResource(hs))
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py
new file mode 100644
index 000000000..c03194f00
--- /dev/null
+++ b/synapse/rest/oidc/callback_resource.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.http.server import DirectServeResource, wrap_html_request_handler
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCCallbackResource(DirectServeResource):
+ isLeaf = 1
+
+ def __init__(self, hs):
+ super().__init__()
+ self._oidc_handler = hs.get_oidc_handler()
+
+ @wrap_html_request_handler
+ async def _async_render_GET(self, request):
+ return await self._oidc_handler.handle_oidc_callback(request)
diff --git a/synapse/server.py b/synapse/server.py
index bf97a16c0..b4aea81e2 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -204,6 +204,7 @@ class HomeServer(object):
"account_validity_handler",
"cas_handler",
"saml_handler",
+ "oidc_handler",
"event_client_serializer",
"password_policy_handler",
"storage",
@@ -562,6 +563,11 @@ class HomeServer(object):
return SamlHandler(self)
+ def build_oidc_handler(self):
+ from synapse.handlers.oidc_handler import OidcHandler
+
+ return OidcHandler(self)
+
def build_event_client_serializer(self):
return EventClientSerializer(self)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 18043a259..31a9cc038 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -13,6 +13,7 @@ import synapse.handlers.device
import synapse.handlers.e2e_keys
import synapse.handlers.message
import synapse.handlers.presence
+import synapse.handlers.register
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
@@ -128,3 +129,7 @@ class HomeServer(object):
pass
def get_storage(self) -> synapse.storage.Storage:
pass
+ def get_registration_handler(self) -> synapse.handlers.register.RegistrationHandler:
+ pass
+ def get_macaroon_generator(self) -> synapse.handlers.auth.MacaroonGenerator:
+ pass
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
new file mode 100644
index 000000000..61963aa90
--- /dev/null
+++ b/tests/handlers/test_oidc.py
@@ -0,0 +1,565 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from urllib.parse import parse_qs, urlparse
+
+from mock import Mock, patch
+
+import attr
+import pymacaroons
+
+from twisted.internet import defer
+from twisted.python.failure import Failure
+from twisted.web._newclient import ResponseDone
+
+from synapse.handlers.oidc_handler import (
+ MappingException,
+ OidcError,
+ OidcHandler,
+ OidcMappingProvider,
+)
+from synapse.types import UserID
+
+from tests.unittest import HomeserverTestCase, override_config
+
+
+@attr.s
+class FakeResponse:
+ code = attr.ib()
+ body = attr.ib()
+ phrase = attr.ib()
+
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
+
+
+# These are a few constants that are used as config parameters in the tests.
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+SCOPES = ["openid"]
+
+AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
+TOKEN_ENDPOINT = ISSUER + "token"
+USERINFO_ENDPOINT = ISSUER + "userinfo"
+WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
+JWKS_URI = ISSUER + ".well-known/jwks.json"
+
+# config for common cases
+COMMON_CONFIG = {
+ "discover": False,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+}
+
+
+# The cookie name and path don't really matter, just that it has to be coherent
+# between the callback & redirect handlers.
+COOKIE_NAME = b"oidc_session"
+COOKIE_PATH = "/_synapse/oidc"
+
+MockedMappingProvider = Mock(OidcMappingProvider)
+
+
+def simple_async_mock(return_value=None, raises=None):
+ # AsyncMock is not available in python3.5, this mimics part of its behaviour
+ async def cb(*args, **kwargs):
+ if raises:
+ raise raises
+ return return_value
+
+ return Mock(side_effect=cb)
+
+
+async def get_json(url):
+ # Mock get_json calls to handle jwks & oidc discovery endpoints
+ if url == WELL_KNOWN:
+ # Minimal discovery document, as defined in OpenID.Discovery
+ # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
+ return {
+ "issuer": ISSUER,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+ "userinfo_endpoint": USERINFO_ENDPOINT,
+ "response_types_supported": ["code"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+ elif url == JWKS_URI:
+ return {"keys": []}
+
+
+class OidcHandlerTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = "Synapse Test"
+
+ config = self.default_config()
+ config["public_baseurl"] = BASE_URL
+ oidc_config = config.get("oidc_config", {})
+ oidc_config["enabled"] = True
+ oidc_config["client_id"] = CLIENT_ID
+ oidc_config["client_secret"] = CLIENT_SECRET
+ oidc_config["issuer"] = ISSUER
+ oidc_config["scopes"] = SCOPES
+ oidc_config["user_mapping_provider"] = {
+ "module": __name__ + ".MockedMappingProvider"
+ }
+ config["oidc_config"] = oidc_config
+
+ hs = self.setup_test_homeserver(
+ http_client=self.http_client,
+ proxied_http_client=self.http_client,
+ config=config,
+ )
+
+ self.handler = OidcHandler(hs)
+
+ return hs
+
+ def metadata_edit(self, values):
+ return patch.dict(self.handler._provider_metadata, values)
+
+ def assertRenderedError(self, error, error_description=None):
+ args = self.handler._render_error.call_args[0]
+ self.assertEqual(args[1], error)
+ if error_description is not None:
+ self.assertEqual(args[2], error_description)
+ # Reset the render_error mock
+ self.handler._render_error.reset_mock()
+
+ def test_config(self):
+ """Basic config correctly sets up the callback URL and client auth correctly."""
+ self.assertEqual(self.handler._callback_url, CALLBACK_URL)
+ self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
+ self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+
+ @override_config({"oidc_config": {"discover": True}})
+ @defer.inlineCallbacks
+ def test_discovery(self):
+ """The handler should discover the endpoints from OIDC discovery document."""
+ # This would throw if some metadata were invalid
+ metadata = yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+
+ self.assertEqual(metadata.issuer, ISSUER)
+ self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
+ self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
+ self.assertEqual(metadata.jwks_uri, JWKS_URI)
+ # FIXME: it seems like authlib does not have that defined in its metadata models
+ # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+
+ # subsequent calls should be cached
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_no_discovery(self):
+ """When discovery is disabled, it should not try to load from discovery document."""
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_load_jwks(self):
+ """JWKS loading is done once (then cached) if used."""
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+ self.assertEqual(jwks, {"keys": []})
+
+ # subsequent calls should be cached…
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_not_called()
+
+ # …unless forced
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+
+ # Throw if the JWKS uri is missing
+ with self.metadata_edit({"jwks_uri": None}):
+ with self.assertRaises(RuntimeError):
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+
+ # Return empty key set if JWKS are not used
+ self.handler._scopes = [] # not asking the openid scope
+ self.http_client.get_json.reset_mock()
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_not_called()
+ self.assertEqual(jwks, {"keys": []})
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ def test_validate_config(self):
+ """Provider metadatas are extensively validated."""
+ h = self.handler
+
+ # Default test config does not throw
+ h._validate_metadata()
+
+ with self.metadata_edit({"issuer": None}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "http://insecure/"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"authorization_endpoint": None}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"token_endpoint": None}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": None}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"response_types_supported": ["id_token"]}):
+ self.assertRaisesRegex(
+ ValueError, "response_types_supported", h._validate_metadata
+ )
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
+ ):
+ # should not throw, as client_secret_basic is the default auth method
+ h._validate_metadata()
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_post"]}
+ ):
+ self.assertRaisesRegex(
+ ValueError,
+ "token_endpoint_auth_methods_supported",
+ h._validate_metadata,
+ )
+
+ # Tests for configs that the userinfo endpoint
+ self.assertFalse(h._uses_userinfo)
+ h._scopes = [] # do not request the openid scope
+ self.assertTrue(h._uses_userinfo)
+ self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
+
+ with self.metadata_edit(
+ {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
+ ):
+ # Shouldn't raise with a valid userinfo, even without
+ h._validate_metadata()
+
+ @override_config({"oidc_config": {"skip_verification": True}})
+ def test_skip_verification(self):
+ """Provider metadata validation can be disabled by config."""
+ with self.metadata_edit({"issuer": "http://insecure"}):
+ # This should not throw
+ self.handler._validate_metadata()
+
+ @defer.inlineCallbacks
+ def test_redirect_request(self):
+ """The redirect request has the right arguments & generates a valid session cookie."""
+ req = Mock(spec=["addCookie", "redirect", "finish"])
+ yield defer.ensureDeferred(
+ self.handler.handle_redirect_request(req, b"http://client/redirect")
+ )
+ url = req.redirect.call_args[0][0]
+ url = urlparse(url)
+ auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+
+ self.assertEqual(url.scheme, auth_endpoint.scheme)
+ self.assertEqual(url.netloc, auth_endpoint.netloc)
+ self.assertEqual(url.path, auth_endpoint.path)
+
+ params = parse_qs(url.query)
+ self.assertEqual(params["redirect_uri"], [CALLBACK_URL])
+ self.assertEqual(params["response_type"], ["code"])
+ self.assertEqual(params["scope"], [" ".join(SCOPES)])
+ self.assertEqual(params["client_id"], [CLIENT_ID])
+ self.assertEqual(len(params["state"]), 1)
+ self.assertEqual(len(params["nonce"]), 1)
+
+ # Check what is in the cookie
+ # note: python3.5 mock does not have the .called_once() method
+ calls = req.addCookie.call_args_list
+ self.assertEqual(len(calls), 1) # called once
+ # For some reason, call.args does not work with python3.5
+ args = calls[0][0]
+ kwargs = calls[0][1]
+ self.assertEqual(args[0], COOKIE_NAME)
+ self.assertEqual(kwargs["path"], COOKIE_PATH)
+ cookie = args[1]
+
+ macaroon = pymacaroons.Macaroon.deserialize(cookie)
+ state = self.handler._get_value_from_macaroon(macaroon, "state")
+ nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
+ redirect = self.handler._get_value_from_macaroon(
+ macaroon, "client_redirect_url"
+ )
+
+ self.assertEqual(params["state"], [state])
+ self.assertEqual(params["nonce"], [nonce])
+ self.assertEqual(redirect, "http://client/redirect")
+
+ @defer.inlineCallbacks
+ def test_callback_error(self):
+ """Errors from the provider returned in the callback are displayed."""
+ self.handler._render_error = Mock()
+ request = Mock(args={})
+ request.args[b"error"] = [b"invalid_client"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "")
+
+ request.args[b"error_description"] = [b"some description"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "some description")
+
+ @defer.inlineCallbacks
+ def test_callback(self):
+ """Code callback works and display errors if something went wrong.
+
+ A lot of scenarios are tested here:
+ - when the callback works, with userinfo from ID token
+ - when the user mapping fails
+ - when ID token verification fails
+ - when the callback works, with userinfo fetched from the userinfo endpoint
+ - when the userinfo fetching fails
+ - when the code exchange fails
+ """
+ token = {
+ "type": "bearer",
+ "id_token": "id_token",
+ "access_token": "access_token",
+ }
+ userinfo = {
+ "sub": "foo",
+ "preferred_username": "bar",
+ }
+ user_id = UserID("foo", "domain.org")
+ self.handler._render_error = Mock(return_value=None)
+ self.handler._exchange_code = simple_async_mock(return_value=token)
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ self.handler._auth_handler.complete_sso_login = simple_async_mock()
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ code = "code"
+ state = "state"
+ nonce = "nonce"
+ client_redirect_url = "http://client/redirect"
+ session = self.handler._generate_oidc_session_token(
+ state=state, nonce=nonce, client_redirect_url=client_redirect_url,
+ )
+ request.getCookie.return_value = session
+
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_not_called()
+ self.handler._render_error.assert_not_called()
+
+ # Handle mapping errors
+ self.handler._map_userinfo_to_user = simple_async_mock(
+ raises=MappingException()
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+
+ # Handle ID token errors
+ self.handler._parse_id_token = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_token")
+
+ self.handler._auth_handler.complete_sso_login.reset_mock()
+ self.handler._exchange_code.reset_mock()
+ self.handler._parse_id_token.reset_mock()
+ self.handler._map_userinfo_to_user.reset_mock()
+ self.handler._fetch_userinfo.reset_mock()
+
+ # With userinfo fetching
+ self.handler._scopes = [] # do not ask the "openid" scope
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_not_called()
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_called_once_with(token)
+ self.handler._render_error.assert_not_called()
+
+ # Handle userinfo fetching error
+ self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("fetch_error")
+
+ # Handle code exchange failure
+ self.handler._exchange_code = simple_async_mock(
+ raises=OidcError("invalid_request")
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @defer.inlineCallbacks
+ def test_callback_session(self):
+ """The callback verifies the session presence and validity"""
+ self.handler._render_error = Mock(return_value=None)
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ # Missing cookie
+ request.args = {}
+ request.getCookie.return_value = None
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("missing_session", "No session cookie found")
+
+ # Missing session parameter
+ request.args = {}
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request", "State parameter is missing")
+
+ # Invalid cookie
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_session")
+
+ # Mismatching session
+ session = self.handler._generate_oidc_session_token(
+ state="state", nonce="nonce", client_redirect_url="http://client/redirect",
+ )
+ request.args = {}
+ request.args[b"state"] = [b"mismatching state"]
+ request.getCookie.return_value = session
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mismatching_session")
+
+ # Valid session
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = session
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+ @defer.inlineCallbacks
+ def test_exchange_code(self):
+ """Code exchange behaves correctly and handles various error scenarios."""
+ token = {"type": "bearer"}
+ token_json = json.dumps(token).encode("utf-8")
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+ )
+ code = "code"
+ ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
+ kwargs = self.http_client.request.call_args[1]
+
+ self.assertEqual(ret, token)
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["client_secret"], [CLIENT_SECRET])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ # Test error handling
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=400,
+ phrase=b"Bad Request",
+ body=b'{"error": "foo", "error_description": "bar"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "foo")
+ self.assertEqual(exc.exception.error_description, "bar")
+
+ # Internal server error with no JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500, phrase=b"Internal Server Error", body=b"Not JSON",
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # Internal server error with JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500,
+ phrase=b"Internal Server Error",
+ body=b'{"error": "internal_server_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "internal_server_error")
+
+ # 4xx error without "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # 2xx error with "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=b'{"error": "some_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "some_error")
diff --git a/tox.ini b/tox.ini
index c699f3e46..ad1902d47 100644
--- a/tox.ini
+++ b/tox.ini
@@ -185,6 +185,7 @@ commands = mypy \
synapse/handlers/auth.py \
synapse/handlers/cas_handler.py \
synapse/handlers/directory.py \
+ synapse/handlers/oidc_handler.py \
synapse/handlers/presence.py \
synapse/handlers/saml_handler.py \
synapse/handlers/sync.py \