Clean up caching/locking of OIDC metadata load (#9362)

Ensure that we lock correctly to prevent multiple concurrent metadata load
requests, and generally clean up the way we construct the metadata cache.
This commit is contained in:
Richard van der Hoff 2021-02-16 16:27:38 +00:00 committed by GitHub
parent 0ad087273c
commit 3b754aea27
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 389 additions and 62 deletions

View file

@ -41,6 +41,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -245,6 +246,7 @@ class OidcProvider:
self._token_generator = token_generator
self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = provider.scopes
@ -253,14 +255,16 @@ class OidcProvider:
provider.client_id, provider.client_secret, provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = provider.client_auth_method
self._provider_metadata = OpenIDProviderMetadata(
issuer=provider.issuer,
authorization_endpoint=provider.authorization_endpoint,
token_endpoint=provider.token_endpoint,
userinfo_endpoint=provider.userinfo_endpoint,
jwks_uri=provider.jwks_uri,
) # type: OpenIDProviderMetadata
self._provider_needs_discovery = provider.discover
# cache of metadata for the identity provider (endpoint uris, mostly). This is
# loaded on-demand from the discovery endpoint (if discovery is enabled), with
# possible overrides from the config. Access via `load_metadata`.
self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
# cache of JWKs used by the identity provider to sign tokens. Loaded on demand
# from the IdP's jwks_uri, if required.
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
)
@ -286,7 +290,7 @@ class OidcProvider:
self._sso_handler.register_identity_provider(self)
def _validate_metadata(self):
def _validate_metadata(self, m: OpenIDProviderMetadata) -> None:
"""Verifies the provider metadata.
This checks the validity of the currently loaded provider. Not
@ -305,7 +309,6 @@ class OidcProvider:
if self._skip_verification is True:
return
m = self._provider_metadata
m.validate_issuer()
m.validate_authorization_endpoint()
m.validate_token_endpoint()
@ -340,11 +343,7 @@ class OidcProvider:
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
if m.get("jwks") is None:
if m.get("jwks_uri") is not None:
m.validate_jwks_uri()
else:
raise ValueError('"jwks_uri" must be set')
m.validate_jwks_uri()
@property
def _uses_userinfo(self) -> bool:
@ -361,11 +360,15 @@ class OidcProvider:
or self._user_profile_method == "userinfo_endpoint"
)
async def load_metadata(self) -> OpenIDProviderMetadata:
"""Load and validate the provider metadata.
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
"""Return the provider metadata.
The values metadatas are discovered if ``oidc_config.discovery`` is
``True`` and then cached.
If this is the first call, the metadata is built from the config and from the
metadata discovery endpoint (if enabled), and then validated. If the metadata
is successfully validated, it is then cached for future use.
Args:
force: If true, any cached metadata is discarded to force a reload.
Raises:
ValueError: if something in the provider is not valid
@ -373,18 +376,32 @@ class OidcProvider:
Returns:
The provider's metadata.
"""
# If we are using the OpenID Discovery documents, it needs to be loaded once
# FIXME: should there be a lock here?
if self._provider_needs_discovery:
url = get_well_known_url(self._provider_metadata["issuer"], external=True)
if force:
# reset the cached call to ensure we get a new result
self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
return await self._provider_metadata.get()
async def _load_metadata(self) -> OpenIDProviderMetadata:
# init the metadata from our config
metadata = OpenIDProviderMetadata(
issuer=self._config.issuer,
authorization_endpoint=self._config.authorization_endpoint,
token_endpoint=self._config.token_endpoint,
userinfo_endpoint=self._config.userinfo_endpoint,
jwks_uri=self._config.jwks_uri,
)
# load any data from the discovery endpoint, if enabled
if self._config.discover:
url = get_well_known_url(self._config.issuer, external=True)
metadata_response = await self._http_client.get_json(url)
# TODO: maybe update the other way around to let user override some values?
self._provider_metadata.update(metadata_response)
self._provider_needs_discovery = False
metadata.update(metadata_response)
self._validate_metadata()
self._validate_metadata(metadata)
return self._provider_metadata
return metadata
async def load_jwks(self, force: bool = False) -> JWKS:
"""Load the JSON Web Key Set used to sign ID tokens.
@ -414,27 +431,27 @@ class OidcProvider:
]
}
"""
if force:
# reset the cached call to ensure we get a new result
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
return await self._jwks.get()
async def _load_jwks(self) -> JWKS:
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}
# First check if the JWKS are loaded in the provider metadata.
# It can happen either if the provider gives its JWKS in the discovery
# document directly or if it was already loaded once.
metadata = await self.load_metadata()
jwk_set = metadata.get("jwks")
if jwk_set is not None and not force:
return jwk_set
# Loading the JWKS using the `jwks_uri` metadata
# Load the JWKS using the `jwks_uri` metadata.
uri = metadata.get("jwks_uri")
if not uri:
# this should be unreachable: load_metadata validates that
# there is a jwks_uri in the metadata if _uses_userinfo is unset
raise RuntimeError('Missing "jwks_uri" in metadata')
jwk_set = await self._http_client.get_json(uri)
# Caching the JWKS in the provider's metadata
self._provider_metadata["jwks"] = jwk_set
return jwk_set
async def _exchange_code(self, code: str) -> Token: