Clean up caching/locking of OIDC metadata load (#9362)

Ensure that we lock correctly to prevent multiple concurrent metadata load
requests, and generally clean up the way we construct the metadata cache.
This commit is contained in:
Richard van der Hoff 2021-02-16 16:27:38 +00:00 committed by GitHub
parent 0ad087273c
commit 3b754aea27
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 389 additions and 62 deletions

1
changelog.d/9362.misc Normal file
View File

@ -0,0 +1 @@
Clean up the code to load the metadata for OpenID Connect identity providers.

View File

@ -41,6 +41,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -245,6 +246,7 @@ class OidcProvider:
self._token_generator = token_generator
self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = provider.scopes
@ -253,14 +255,16 @@ class OidcProvider:
provider.client_id, provider.client_secret, provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = provider.client_auth_method
self._provider_metadata = OpenIDProviderMetadata(
issuer=provider.issuer,
authorization_endpoint=provider.authorization_endpoint,
token_endpoint=provider.token_endpoint,
userinfo_endpoint=provider.userinfo_endpoint,
jwks_uri=provider.jwks_uri,
) # type: OpenIDProviderMetadata
self._provider_needs_discovery = provider.discover
# cache of metadata for the identity provider (endpoint uris, mostly). This is
# loaded on-demand from the discovery endpoint (if discovery is enabled), with
# possible overrides from the config. Access via `load_metadata`.
self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
# cache of JWKs used by the identity provider to sign tokens. Loaded on demand
# from the IdP's jwks_uri, if required.
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
)
@ -286,7 +290,7 @@ class OidcProvider:
self._sso_handler.register_identity_provider(self)
def _validate_metadata(self):
def _validate_metadata(self, m: OpenIDProviderMetadata) -> None:
"""Verifies the provider metadata.
This checks the validity of the currently loaded provider. Not
@ -305,7 +309,6 @@ class OidcProvider:
if self._skip_verification is True:
return
m = self._provider_metadata
m.validate_issuer()
m.validate_authorization_endpoint()
m.validate_token_endpoint()
@ -340,11 +343,7 @@ class OidcProvider:
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
if m.get("jwks") is None:
if m.get("jwks_uri") is not None:
m.validate_jwks_uri()
else:
raise ValueError('"jwks_uri" must be set')
m.validate_jwks_uri()
@property
def _uses_userinfo(self) -> bool:
@ -361,11 +360,15 @@ class OidcProvider:
or self._user_profile_method == "userinfo_endpoint"
)
async def load_metadata(self) -> OpenIDProviderMetadata:
"""Load and validate the provider metadata.
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
"""Return the provider metadata.
The values metadatas are discovered if ``oidc_config.discovery`` is
``True`` and then cached.
If this is the first call, the metadata is built from the config and from the
metadata discovery endpoint (if enabled), and then validated. If the metadata
is successfully validated, it is then cached for future use.
Args:
force: If true, any cached metadata is discarded to force a reload.
Raises:
ValueError: if something in the provider is not valid
@ -373,18 +376,32 @@ class OidcProvider:
Returns:
The provider's metadata.
"""
# If we are using the OpenID Discovery documents, it needs to be loaded once
# FIXME: should there be a lock here?
if self._provider_needs_discovery:
url = get_well_known_url(self._provider_metadata["issuer"], external=True)
if force:
# reset the cached call to ensure we get a new result
self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
return await self._provider_metadata.get()
async def _load_metadata(self) -> OpenIDProviderMetadata:
# init the metadata from our config
metadata = OpenIDProviderMetadata(
issuer=self._config.issuer,
authorization_endpoint=self._config.authorization_endpoint,
token_endpoint=self._config.token_endpoint,
userinfo_endpoint=self._config.userinfo_endpoint,
jwks_uri=self._config.jwks_uri,
)
# load any data from the discovery endpoint, if enabled
if self._config.discover:
url = get_well_known_url(self._config.issuer, external=True)
metadata_response = await self._http_client.get_json(url)
# TODO: maybe update the other way around to let user override some values?
self._provider_metadata.update(metadata_response)
self._provider_needs_discovery = False
metadata.update(metadata_response)
self._validate_metadata()
self._validate_metadata(metadata)
return self._provider_metadata
return metadata
async def load_jwks(self, force: bool = False) -> JWKS:
"""Load the JSON Web Key Set used to sign ID tokens.
@ -414,27 +431,27 @@ class OidcProvider:
]
}
"""
if force:
# reset the cached call to ensure we get a new result
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
return await self._jwks.get()
async def _load_jwks(self) -> JWKS:
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}
# First check if the JWKS are loaded in the provider metadata.
# It can happen either if the provider gives its JWKS in the discovery
# document directly or if it was already loaded once.
metadata = await self.load_metadata()
jwk_set = metadata.get("jwks")
if jwk_set is not None and not force:
return jwk_set
# Loading the JWKS using the `jwks_uri` metadata
# Load the JWKS using the `jwks_uri` metadata.
uri = metadata.get("jwks_uri")
if not uri:
# this should be unreachable: load_metadata validates that
# there is a jwks_uri in the metadata if _uses_userinfo is unset
raise RuntimeError('Missing "jwks_uri" in metadata')
jwk_set = await self._http_client.get_json(uri)
# Caching the JWKS in the provider's metadata
self._provider_metadata["jwks"] = jwk_set
return jwk_set
async def _exchange_code(self, code: str) -> Token:

View File

@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
TV = TypeVar("TV")
class CachedCall(Generic[TV]):
"""A wrapper for asynchronous calls whose results should be shared
This is useful for wrapping asynchronous functions, where there might be multiple
callers, but we only want to call the underlying function once (and have the result
returned to all callers).
Similar results can be achieved via a lock of some form, but that typically requires
more boilerplate (and ends up being less efficient).
Correctly handles Synapse logcontexts (logs and resource usage for the underlying
function are logged against the logcontext which is active when get() is first
called).
Example usage:
_cached_val = CachedCall(_load_prop)
async def handle_request() -> X:
# We can call this multiple times, but it will result in a single call to
# _load_prop().
return await _cached_val.get()
async def _load_prop() -> X:
await difficult_operation()
The implementation is deliberately single-shot (ie, once the call is initiated,
there is no way to ask for it to be run). This keeps the implementation and
semantics simple. If you want to make a new call, simply replace the whole
CachedCall object.
"""
__slots__ = ["_callable", "_deferred", "_result"]
def __init__(self, f: Callable[[], Awaitable[TV]]):
"""
Args:
f: The underlying function. Only one call to this function will be alive
at once (per instance of CachedCall)
"""
self._callable = f # type: Optional[Callable[[], Awaitable[TV]]]
self._deferred = None # type: Optional[Deferred]
self._result = None # type: Union[None, Failure, TV]
async def get(self) -> TV:
"""Kick off the call if necessary, and return the result"""
# Fire off the callable now if this is our first time
if not self._deferred:
self._deferred = run_in_background(self._callable)
# we will never need the callable again, so make sure it can be GCed
self._callable = None
# once the deferred completes, store the result. We cannot simply leave the
# result in the deferred, since if it's a Failure, GCing the deferred
# would then log a critical error about unhandled Failures.
def got_result(r):
self._result = r
self._deferred.addBoth(got_result)
# TODO: consider cancellation semantics. Currently, if the call to get()
# is cancelled, the underlying call will continue (and any future calls
# will get the result/exception), which I think is *probably* ok, modulo
# the fact the underlying call may be logged to a cancelled logcontext,
# and any eventual exception may not be reported.
# we can now await the deferred, and once it completes, return the result.
await make_deferred_yieldable(self._deferred)
# I *think* this is the easiest way to correctly raise a Failure without having
# to gut-wrench into the implementation of Deferred.
d = Deferred()
d.callback(self._result)
return await d
class RetryOnExceptionCachedCall(Generic[TV]):
"""A wrapper around CachedCall which will retry the call if an exception is thrown
This is used in much the same way as CachedCall, but adds some extra functionality
so that if the underlying function throws an exception, then the next call to get()
will initiate another call to the underlying function. (Any calls to get() which
are already pending will raise the exception.)
"""
slots = ["_cachedcall"]
def __init__(self, f: Callable[[], Awaitable[TV]]):
async def _wrapper() -> TV:
try:
return await f()
except Exception:
# the call raised an exception: replace the underlying CachedCall to
# trigger another call next time get() is called
self._cachedcall = CachedCall(_wrapper)
raise
self._cachedcall = CachedCall(_wrapper)
async def get(self) -> TV:
return await self._cachedcall.get()

View File

@ -24,7 +24,7 @@ from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID
from tests.test_utils import FakeResponse, simple_async_mock
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
try:
@ -131,7 +131,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor, clock):
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = "Synapse Test"
@ -151,7 +150,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def metadata_edit(self, values):
return patch.dict(self.provider._provider_metadata, values)
"""Modify the result that will be returned by the well-known query"""
async def patched_get_json(uri):
res = await get_json(uri)
if uri == WELL_KNOWN:
res.update(values)
return res
return patch.object(self.http_client, "get_json", patched_get_json)
def assertRenderedError(self, error, error_description=None):
self.render_error.assert_called_once()
@ -212,7 +219,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
original = self.provider.load_metadata
async def patched_load_metadata():
m = (await original()).copy()
m.update({"jwks_uri": None})
return m
with patch.object(self.provider, "load_metadata", patched_load_metadata):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
@ -222,55 +236,60 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
@override_config({"oidc_config": COMMON_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
h = self.provider
def force_load_metadata():
async def force_load():
return await h.load_metadata(force=True)
return get_awaitable_result(force_load())
# Default test config does not throw
h._validate_metadata()
force_load_metadata()
with self.metadata_edit({"issuer": None}):
self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
with self.metadata_edit({"issuer": "http://insecure/"}):
self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
with self.metadata_edit({"authorization_endpoint": None}):
self.assertRaisesRegex(
ValueError, "authorization_endpoint", h._validate_metadata
ValueError, "authorization_endpoint", force_load_metadata
)
with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
self.assertRaisesRegex(
ValueError, "authorization_endpoint", h._validate_metadata
ValueError, "authorization_endpoint", force_load_metadata
)
with self.metadata_edit({"token_endpoint": None}):
self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
with self.metadata_edit({"jwks_uri": None}):
self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
with self.metadata_edit({"response_types_supported": ["id_token"]}):
self.assertRaisesRegex(
ValueError, "response_types_supported", h._validate_metadata
ValueError, "response_types_supported", force_load_metadata
)
with self.metadata_edit(
{"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
):
# should not throw, as client_secret_basic is the default auth method
h._validate_metadata()
force_load_metadata()
with self.metadata_edit(
{"token_endpoint_auth_methods_supported": ["client_secret_post"]}
@ -278,7 +297,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRaisesRegex(
ValueError,
"token_endpoint_auth_methods_supported",
h._validate_metadata,
force_load_metadata,
)
# Tests for configs that require the userinfo endpoint
@ -287,24 +306,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
h._user_profile_method = "userinfo_endpoint"
self.assertTrue(h._uses_userinfo)
# Revert the profile method and do not request the "openid" scope.
# Revert the profile method and do not request the "openid" scope: this should
# mean that we check for a userinfo endpoint
h._user_profile_method = "auto"
h._scopes = []
self.assertTrue(h._uses_userinfo)
self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
with self.metadata_edit({"userinfo_endpoint": None}):
self.assertRaisesRegex(ValueError, "userinfo_endpoint", force_load_metadata)
with self.metadata_edit(
{"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
):
# Shouldn't raise with a valid userinfo, even without
h._validate_metadata()
with self.metadata_edit({"jwks_uri": None}):
# Shouldn't raise with a valid userinfo, even without jwks
force_load_metadata()
@override_config({"oidc_config": {"skip_verification": True}})
def test_skip_verification(self):
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
self.provider._validate_metadata()
get_awaitable_result(self.provider.load_metadata())
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""

View File

@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.util.caches.cached_call import CachedCall, RetryOnExceptionCachedCall
from tests.test_utils import get_awaitable_result
from tests.unittest import TestCase
class CachedCallTestCase(TestCase):
def test_get(self):
"""
Happy-path test case: makes a couple of calls and makes sure they behave
correctly
"""
d = Deferred()
async def f():
return await d
slow_call = Mock(side_effect=f)
cached_call = CachedCall(slow_call)
# the mock should not yet have been called
slow_call.assert_not_called()
# now fire off a couple of calls
completed_results = []
async def r():
res = await cached_call.get()
completed_results.append(res)
r1 = defer.ensureDeferred(r())
r2 = defer.ensureDeferred(r())
# neither result should be complete yet
self.assertNoResult(r1)
self.assertNoResult(r2)
# and the mock should have been called *once*, with no params
slow_call.assert_called_once_with()
# allow the deferred to complete, which should complete both the pending results
d.callback(123)
self.assertEqual(completed_results, [123, 123])
self.successResultOf(r1)
self.successResultOf(r2)
# another call to the getter should complete immediately
slow_call.reset_mock()
r3 = get_awaitable_result(cached_call.get())
self.assertEqual(r3, 123)
slow_call.assert_not_called()
def test_fast_call(self):
"""
Test the behaviour when the underlying function completes immediately
"""
async def f():
return 12
fast_call = Mock(side_effect=f)
cached_call = CachedCall(fast_call)
# the mock should not yet have been called
fast_call.assert_not_called()
# run the call a couple of times, which should complete immediately
self.assertEqual(get_awaitable_result(cached_call.get()), 12)
self.assertEqual(get_awaitable_result(cached_call.get()), 12)
# the mock should have been called once
fast_call.assert_called_once_with()
class RetryOnExceptionCachedCallTestCase(TestCase):
def test_get(self):
# set up the RetryOnExceptionCachedCall around a function which will fail
# (after a while)
d = Deferred()
async def f1():
await d
raise ValueError("moo")
slow_call = Mock(side_effect=f1)
cached_call = RetryOnExceptionCachedCall(slow_call)
# the mock should not yet have been called
slow_call.assert_not_called()
# now fire off a couple of calls
completed_results = []
async def r():
try:
await cached_call.get()
except Exception as e1:
completed_results.append(e1)
r1 = defer.ensureDeferred(r())
r2 = defer.ensureDeferred(r())
# neither result should be complete yet
self.assertNoResult(r1)
self.assertNoResult(r2)
# and the mock should have been called *once*, with no params
slow_call.assert_called_once_with()
# complete the deferred, which should make the pending calls fail
d.callback(0)
self.assertEqual(len(completed_results), 2)
for e in completed_results:
self.assertIsInstance(e, ValueError)
self.assertEqual(e.args, ("moo",))
# reset the mock to return a successful result, and make another pair of calls
# to the getter
d = Deferred()
async def f2():
return await d
slow_call.reset_mock()
slow_call.side_effect = f2
r3 = defer.ensureDeferred(cached_call.get())
r4 = defer.ensureDeferred(cached_call.get())
self.assertNoResult(r3)
self.assertNoResult(r4)
slow_call.assert_called_once_with()
# let that call complete, and check the results
d.callback(123)
self.assertEqual(self.successResultOf(r3), 123)
self.assertEqual(self.successResultOf(r4), 123)
# and now more calls to the getter should complete immediately
slow_call.reset_mock()
self.assertEqual(get_awaitable_result(cached_call.get()), 123)
slow_call.assert_not_called()