Refactor OIDC tests to better mimic an actual OIDC provider. (#13910)

This implements a fake OIDC server, which intercepts calls to the HTTP client.
Improves accuracy of tests by covering more internal methods.

One particular example was the ID token validation, which previously mocked.

This uncovered an incorrect dependency: Synapse actually requires at least
authlib 0.15.1, not 0.14.0.
This commit is contained in:
Quentin Gliech 2022-10-25 16:25:02 +02:00 committed by GitHub
parent 2d0ba3f89a
commit 9192d74b0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 747 additions and 460 deletions

1
changelog.d/13910.misc Normal file
View File

@ -0,0 +1 @@
Refactor OIDC tests to better mimic an actual OIDC provider.

View File

@ -192,7 +192,7 @@ psycopg2 = { version = ">=2.8", markers = "platform_python_implementation != 'Py
psycopg2cffi = { version = ">=2.8", markers = "platform_python_implementation == 'PyPy'", optional = true }
psycopg2cffi-compat = { version = "==1.1", markers = "platform_python_implementation == 'PyPy'", optional = true }
pysaml2 = { version = ">=4.5.0", optional = true }
authlib = { version = ">=0.14.0", optional = true }
authlib = { version = ">=0.15.1", optional = true }
# systemd-python is necessary for logging to the systemd journal via
# `systemd.journal.JournalHandler`, as is documented in
# `contrib/systemd/log_config.yaml`.

View File

@ -275,6 +275,7 @@ class OidcProvider:
provider: OidcProviderConfig,
):
self._store = hs.get_datastores().main
self._clock = hs.get_clock()
self._macaroon_generaton = macaroon_generator
@ -673,6 +674,13 @@ class OidcProvider:
Returns:
The decoded claims in the ID token.
"""
id_token = token.get("id_token")
logger.debug("Attempting to decode JWT id_token %r", id_token)
# That has been theoritically been checked by the caller, so even though
# assertion are not enabled in production, it is mainly here to appease mypy
assert id_token is not None
metadata = await self.load_metadata()
claims_params = {
"nonce": nonce,
@ -688,9 +696,6 @@ class OidcProvider:
claim_options = {"iss": {"values": [metadata["issuer"]]}}
id_token = token["id_token"]
logger.debug("Attempting to decode JWT id_token %r", id_token)
# Try to decode the keys in cache first, then retry by forcing the keys
# to be reloaded
jwk_set = await self.load_jwks()
@ -715,7 +720,9 @@ class OidcProvider:
logger.debug("Decoded id_token JWT %r; validating", claims)
claims.validate(leeway=120) # allows 2 min of clock skew
claims.validate(
now=self._clock.time(), leeway=120
) # allows 2 min of clock skew
return claims

View File

@ -12,13 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from unittest import mock
import twisted.web.client
from twisted.internet import defer
from twisted.internet.protocol import Protocol
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import RoomVersions
@ -26,10 +23,9 @@ from synapse.events import EventBase
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.test_utils import event_injection
from tests.test_utils import FakeResponse, event_injection
from tests.unittest import FederatingHomeserverTestCase
@ -98,8 +94,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
# mock up the response, and have the agent return it
self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
_mock_response(
{
FakeResponse.json(
payload={
"pdus": [
create_event_dict,
member_event_dict,
@ -208,8 +204,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
# mock up the response, and have the agent return it
self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
_mock_response(
{
FakeResponse.json(
payload={
"origin": "yet.another.server",
"origin_server_ts": 900,
"pdus": [
@ -269,8 +265,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
# We expect an outbound request to /backfill, so stub that out
self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
_mock_response(
{
FakeResponse.json(
payload={
"origin": "yet.another.server",
"origin_server_ts": 900,
# Mimic the other server returning our new `pulled_event`
@ -305,21 +301,3 @@ class FederationClientTest(FederatingHomeserverTestCase):
# This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the
# other from "yet.another.server"
self.assertEqual(backfill_num_attempts, 2)
def _mock_response(resp: JsonDict):
body = json.dumps(resp).encode("utf-8")
def deliver_body(p: Protocol):
p.dataReceived(body)
p.connectionLost(Failure(twisted.web.client.ResponseDone()))
response = mock.Mock(
code=200,
phrase=b"OK",
headers=twisted.web.client.Headers({"content-Type": ["application/json"]}),
length=len(body),
deliverBody=deliver_body,
)
mock.seal(response)
return response

File diff suppressed because it is too large Load Diff

View File

@ -465,9 +465,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
* checking that the original operation succeeds
"""
fake_oidc_server = self.helper.fake_oidc_server()
# log the user in
remote_user_id = UserID.from_string(self.user).localpart
login_resp = self.helper.login_via_oidc(remote_user_id)
login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id)
self.assertEqual(login_resp["user_id"], self.user)
# initiate a UI Auth process by attempting to delete the device
@ -481,8 +483,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
# run the UIA-via-SSO flow
session_id = channel.json_body["session"]
channel = self.helper.auth_via_oidc(
{"sub": remote_user_id}, ui_auth_session_id=session_id
channel, _ = self.helper.auth_via_oidc(
fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id
)
# that should serve a confirmation page
@ -499,7 +501,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_does_not_offer_password_for_sso_user(self) -> None:
login_resp = self.helper.login_via_oidc("username")
fake_oidc_server = self.helper.fake_oidc_server()
login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username")
user_tok = login_resp["access_token"]
device_id = login_resp["device_id"]
@ -522,7 +525,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_offers_both_flows_for_upgraded_user(self) -> None:
"""A user that had a password and then logged in with SSO should get both flows"""
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
fake_oidc_server = self.helper.fake_oidc_server()
login_resp, _ = self.helper.login_via_oidc(
fake_oidc_server, UserID.from_string(self.user).localpart
)
self.assertEqual(login_resp["user_id"], self.user)
channel = self.delete_device(
@ -539,8 +545,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
"""If the user tries to authenticate with the wrong SSO user, they get an error"""
fake_oidc_server = self.helper.fake_oidc_server()
# log the user in
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
login_resp, _ = self.helper.login_via_oidc(
fake_oidc_server, UserID.from_string(self.user).localpart
)
self.assertEqual(login_resp["user_id"], self.user)
# start a UI Auth flow by attempting to delete a device
@ -553,8 +564,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
session_id = channel.json_body["session"]
# do the OIDC auth, but auth as the wrong user
channel = self.helper.auth_via_oidc(
{"sub": "wrong_user"}, ui_auth_session_id=session_id
channel, _ = self.helper.auth_via_oidc(
fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id
)
# that should return a failure message
@ -584,7 +595,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
"""Tests that if we register a user via SSO while requiring approval for new
accounts, we still raise the correct error before logging the user in.
"""
login_resp = self.helper.login_via_oidc("username", expected_status=403)
fake_oidc_server = self.helper.fake_oidc_server()
login_resp, _ = self.helper.login_via_oidc(
fake_oidc_server, "username", expected_status=403
)
self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
self.assertEqual(

View File

@ -36,7 +36,7 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.rest.client.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
@ -612,6 +612,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_login_via_oidc(self) -> None:
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
fake_oidc_server = self.helper.fake_oidc_server()
with fake_oidc_server.patch_homeserver(hs=self.hs):
# pick the default OIDC provider
channel = self.make_request(
"GET",
@ -626,7 +629,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
# ... and should have set a cookie including the redirect url
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
@ -643,7 +646,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
TEST_CLIENT_REDIRECT_URL,
)
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
channel, _ = self.helper.complete_oidc_auth(
fake_oidc_server, oidc_uri, cookies, {"sub": "user1"}
)
# that should serve a confirmation page
self.assertEqual(channel.code, 200, channel.result)
@ -693,6 +698,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_client_idp_redirect_to_oidc(self) -> None:
"""If the client pick a known IdP, redirect to it"""
fake_oidc_server = self.helper.fake_oidc_server()
with fake_oidc_server.patch_homeserver(hs=self.hs):
channel = self._make_sso_redirect_request("oidc")
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
@ -701,7 +709,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel:
"""Send a request to /_matrix/client/r0/login/sso/redirect
@ -1280,9 +1288,13 @@ class UsernamePickerTestCase(HomeserverTestCase):
def test_username_picker(self) -> None:
"""Test the happy path of a username picker flow."""
fake_oidc_server = self.helper.fake_oidc_server()
# do the start of the login flow
channel = self.helper.auth_via_oidc(
{"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
channel, _ = self.helper.auth_via_oidc(
fake_oidc_server,
{"sub": "tester", "displayname": "Jonny"},
TEST_CLIENT_REDIRECT_URL,
)
# that should redirect to the username picker

View File

@ -31,7 +31,6 @@ from typing import (
Tuple,
overload,
)
from unittest.mock import patch
from urllib.parse import urlencode
import attr
@ -46,8 +45,19 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from tests.server import FakeChannel, FakeSite, make_request
from tests.test_utils import FakeResponse
from tests.test_utils.html_parsers import TestHtmlParser
from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
# an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_ISSUER = "https://issuer.test/"
TEST_OIDC_CONFIG = {
"enabled": True,
"issuer": TEST_OIDC_ISSUER,
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"scopes": ["openid"],
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
}
@attr.s(auto_attribs=True)
@ -543,12 +553,28 @@ class RestHelper:
return channel.json_body
def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
"""Create a ``FakeOidcServer``.
This can be used in conjuction with ``login_via_oidc``::
fake_oidc_server = self.helper.fake_oidc_server()
login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user")
"""
return FakeOidcServer(
clock=self.hs.get_clock(),
issuer=issuer,
)
def login_via_oidc(
self,
fake_server: FakeOidcServer,
remote_user_id: str,
with_sid: bool = False,
expected_status: int = 200,
) -> JsonDict:
"""Log in via OIDC
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC
Returns the result of the final token login.
@ -560,7 +586,10 @@ class RestHelper:
the normal places.
"""
client_redirect_url = "https://x"
channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
userinfo = {"sub": remote_user_id}
channel, grant = self.auth_via_oidc(
fake_server, userinfo, client_redirect_url, with_sid=with_sid
)
# expect a confirmation page
assert channel.code == HTTPStatus.OK, channel.result
@ -585,14 +614,16 @@ class RestHelper:
assert (
channel.code == expected_status
), f"unexpected status in response: {channel.code}"
return channel.json_body
return channel.json_body, grant
def auth_via_oidc(
self,
fake_server: FakeOidcServer,
user_info_dict: JsonDict,
client_redirect_url: Optional[str] = None,
ui_auth_session_id: Optional[str] = None,
) -> FakeChannel:
with_sid: bool = False,
) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
"""Perform an OIDC authentication flow via a mock OIDC provider.
This can be used for either login or user-interactive auth.
@ -616,6 +647,7 @@ class RestHelper:
the login redirect endpoint
ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
of the UI auth.
with_sid: if True, generates a random `sid` (OIDC session ID)
Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
@ -625,6 +657,7 @@ class RestHelper:
cookies: Dict[str, str] = {}
with fake_server.patch_homeserver(hs=self.hs):
# if we're doing a ui auth, hit the ui auth redirect endpoint
if ui_auth_session_id:
# can't set the client redirect url for UI Auth
@ -640,17 +673,21 @@ class RestHelper:
# that synapse passes to the client.
oauth_uri_path, _ = oauth_uri.split("?", 1)
assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
assert oauth_uri_path == fake_server.authorization_endpoint, (
"unexpected SSO URI " + oauth_uri_path
)
return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
return self.complete_oidc_auth(
fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid
)
def complete_oidc_auth(
self,
fake_serer: FakeOidcServer,
oauth_uri: str,
cookies: Mapping[str, str],
user_info_dict: JsonDict,
) -> FakeChannel:
with_sid: bool = False,
) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
"""Mock out an OIDC authentication flow
Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
@ -661,50 +698,37 @@ class RestHelper:
Requires the OIDC callback resource to be mounted at the normal place.
Args:
fake_server: the fake OIDC server with which the auth should be done
oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
from initiate_sso_login or initiate_sso_ui_auth).
cookies: the cookies set by synapse's redirect endpoint, which will be
sent back to the callback endpoint.
user_info_dict: the remote userinfo that the OIDC provider should present.
Typically this should be '{"sub": "<remote user id>"}'.
with_sid: if True, generates a random `sid` (OIDC session ID)
Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
"""
_, oauth_uri_qs = oauth_uri.split("?", 1)
params = urllib.parse.parse_qs(oauth_uri_qs)
code, grant = fake_serer.start_authorization(
scope=params["scope"][0],
userinfo=user_info_dict,
client_id=params["client_id"][0],
redirect_uri=params["redirect_uri"][0],
nonce=params["nonce"][0],
with_sid=with_sid,
)
state = params["state"][0]
callback_uri = "%s?%s" % (
urllib.parse.urlparse(params["redirect_uri"][0]).path,
urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
urllib.parse.urlencode({"state": state, "code": code}),
)
# before we hit the callback uri, stub out some methods in the http client so
# that we don't have to handle full HTTPS requests.
# (expected url, json response) pairs, in the order we expect them.
expected_requests = [
# first we get a hit to the token endpoint, which we tell to return
# a dummy OIDC access token
(TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
# and then one to the user_info endpoint, which returns our remote user id.
(TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
]
async def mock_req(
method: str,
uri: str,
data: Optional[dict] = None,
headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
(expected_uri, resp_obj) = expected_requests.pop(0)
assert uri == expected_uri
resp = FakeResponse(
code=HTTPStatus.OK,
phrase=b"OK",
body=json.dumps(resp_obj).encode("utf-8"),
)
return resp
with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code
channel = make_request(
self.hs.get_reactor(),
@ -715,7 +739,7 @@ class RestHelper:
("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
],
)
return channel
return channel, grant
def initiate_sso_login(
self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
@ -806,21 +830,3 @@ class RestHelper:
assert len(p.links) == 1, "not exactly one link in confirmation page"
oauth_uri = p.links[0]
return oauth_uri
# an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
TEST_OIDC_CONFIG = {
"enabled": True,
"discover": False,
"issuer": "https://issuer.test",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
"token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
"userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
}

View File

@ -15,17 +15,24 @@
"""
Utilities for running the unit tests
"""
import json
import sys
import warnings
from asyncio import Future
from binascii import unhexlify
from typing import Awaitable, Callable, TypeVar
from typing import Awaitable, Callable, Tuple, TypeVar
from unittest.mock import Mock
import attr
import zope.interface
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.web.http import RESPONSES
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from synapse.types import JsonDict
TV = TypeVar("TV")
@ -97,27 +104,44 @@ def simple_async_mock(return_value=None, raises=None) -> Mock:
return Mock(side_effect=cb)
@attr.s
class FakeResponse:
# Type ignore: it does not fully implement IResponse, but is good enough for tests
@zope.interface.implementer(IResponse)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class FakeResponse: # type: ignore[misc]
"""A fake twisted.web.IResponse object
there is a similar class at treq.test.test_response, but it lacks a `phrase`
attribute, and didn't support deliverBody until recently.
"""
# HTTP response code
code = attr.ib(type=int)
version: Tuple[bytes, int, int] = (b"HTTP", 1, 1)
# HTTP response phrase (eg b'OK' for a 200)
phrase = attr.ib(type=bytes)
# HTTP response code
code: int = 200
# body of the response
body = attr.ib(type=bytes)
body: bytes = b""
headers: Headers = attr.Factory(Headers)
@property
def phrase(self):
return RESPONSES.get(self.code, b"Unknown Status")
@property
def length(self):
return len(self.body)
def deliverBody(self, protocol):
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
@classmethod
def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse":
headers = Headers({"Content-Type": ["application/json"]})
body = json.dumps(payload).encode("utf-8")
return cls(code=code, body=body, headers=headers)
# A small image used in some tests.
#

325
tests/test_utils/oidc.py Normal file
View File

@ -0,0 +1,325 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
from urllib.parse import parse_qs
import attr
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from synapse.server import HomeServer
from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests.test_utils import FakeResponse
@attr.s(slots=True, frozen=True, auto_attribs=True)
class FakeAuthorizationGrant:
userinfo: dict
client_id: str
redirect_uri: str
scope: str
nonce: Optional[str]
sid: Optional[str]
class FakeOidcServer:
"""A fake OpenID Connect Provider."""
# All methods here are mocks, so we can track when they are called, and override
# their values
request: Mock
get_jwks_handler: Mock
get_metadata_handler: Mock
get_userinfo_handler: Mock
post_token_handler: Mock
def __init__(self, clock: Clock, issuer: str):
from authlib.jose import ECKey, KeySet
self._clock = clock
self.issuer = issuer
self.request = Mock(side_effect=self._request)
self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler)
self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler)
self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler)
self.post_token_handler = Mock(side_effect=self._post_token_handler)
# A code -> grant mapping
self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {}
# An access token -> grant mapping
self._sessions: Dict[str, FakeAuthorizationGrant] = {}
# We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for
# signing JWTs. ECDSA keys are really quick to generate compared to RSA.
self._key = ECKey.generate_key(crv="P-256", is_private=True)
self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))])
self._id_token_overrides: Dict[str, Any] = {}
def reset_mocks(self):
self.request.reset_mock()
self.get_jwks_handler.reset_mock()
self.get_metadata_handler.reset_mock()
self.get_userinfo_handler.reset_mock()
self.post_token_handler.reset_mock()
def patch_homeserver(self, hs: HomeServer):
"""Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
This patch should be used whenever the HS is expected to perform request to the
OIDC provider, e.g.::
fake_oidc_server = self.helper.fake_oidc_server()
with fake_oidc_server.patch_homeserver(hs):
self.make_request("GET", "/_matrix/client/r0/login/sso/redirect")
"""
return patch.object(hs.get_proxied_http_client(), "request", self.request)
@property
def authorization_endpoint(self) -> str:
return self.issuer + "authorize"
@property
def token_endpoint(self) -> str:
return self.issuer + "token"
@property
def userinfo_endpoint(self) -> str:
return self.issuer + "userinfo"
@property
def metadata_endpoint(self) -> str:
return self.issuer + ".well-known/openid-configuration"
@property
def jwks_uri(self) -> str:
return self.issuer + "jwks"
def get_metadata(self) -> dict:
return {
"issuer": self.issuer,
"authorization_endpoint": self.authorization_endpoint,
"token_endpoint": self.token_endpoint,
"jwks_uri": self.jwks_uri,
"userinfo_endpoint": self.userinfo_endpoint,
"response_types_supported": ["code"],
"subject_types_supported": ["public"],
"id_token_signing_alg_values_supported": ["ES256"],
}
def get_jwks(self) -> dict:
return self._jwks.as_dict()
def get_userinfo(self, access_token: str) -> Optional[dict]:
"""Given an access token, get the userinfo of the associated session."""
session = self._sessions.get(access_token, None)
if session is None:
return None
return session.userinfo
def _sign(self, payload: dict) -> str:
from authlib.jose import JsonWebSignature
jws = JsonWebSignature()
kid = self.get_jwks()["keys"][0]["kid"]
protected = {"alg": "ES256", "kid": kid}
json_payload = json.dumps(payload)
return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
now = self._clock.time()
id_token = {
**grant.userinfo,
"iss": self.issuer,
"aud": grant.client_id,
"iat": now,
"nbf": now,
"exp": now + 600,
}
if grant.nonce is not None:
id_token["nonce"] = grant.nonce
if grant.sid is not None:
id_token["sid"] = grant.sid
id_token.update(self._id_token_overrides)
return self._sign(id_token)
def id_token_override(self, overrides: dict):
"""Temporarily patch the ID token generated by the token endpoint."""
return patch.object(self, "_id_token_overrides", overrides)
def start_authorization(
self,
client_id: str,
scope: str,
redirect_uri: str,
userinfo: dict,
nonce: Optional[str] = None,
with_sid: bool = False,
) -> Tuple[str, FakeAuthorizationGrant]:
"""Start an authorization request, and get back the code to use on the authorization endpoint."""
code = random_string(10)
sid = None
if with_sid:
sid = random_string(10)
grant = FakeAuthorizationGrant(
userinfo=userinfo,
scope=scope,
redirect_uri=redirect_uri,
nonce=nonce,
client_id=client_id,
sid=sid,
)
self._authorization_grants[code] = grant
return code, grant
def exchange_code(self, code: str) -> Optional[Dict[str, Any]]:
grant = self._authorization_grants.pop(code, None)
if grant is None:
return None
access_token = random_string(10)
self._sessions[access_token] = grant
token = {
"token_type": "Bearer",
"access_token": access_token,
"expires_in": 3600,
"scope": grant.scope,
}
if "openid" in grant.scope:
token["id_token"] = self.generate_id_token(grant)
return dict(token)
def buggy_endpoint(
self,
*,
jwks: bool = False,
metadata: bool = False,
token: bool = False,
userinfo: bool = False,
):
"""A context which makes a set of endpoints return a 500 error.
Args:
jwks: If True, makes the JWKS endpoint return a 500 error.
metadata: If True, makes the OIDC Discovery endpoint return a 500 error.
token: If True, makes the token endpoint return a 500 error.
userinfo: If True, makes the userinfo endpoint return a 500 error.
"""
buggy = FakeResponse(code=500, body=b"Internal server error")
patches = {}
if jwks:
patches["get_jwks_handler"] = Mock(return_value=buggy)
if metadata:
patches["get_metadata_handler"] = Mock(return_value=buggy)
if token:
patches["post_token_handler"] = Mock(return_value=buggy)
if userinfo:
patches["get_userinfo_handler"] = Mock(return_value=buggy)
return patch.multiple(self, **patches)
async def _request(
self,
method: str,
uri: str,
data: Optional[bytes] = None,
headers: Optional[Headers] = None,
) -> IResponse:
"""The override of the SimpleHttpClient#request() method"""
access_token: Optional[str] = None
if headers is None:
headers = Headers()
# Try to find the access token in the headers if any
auth_headers = headers.getRawHeaders(b"Authorization")
if auth_headers:
parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2:
access_token = parts[1].decode("ascii")
if method == "POST":
# If the method is POST, assume it has an url-encoded body
if data is None or headers.getRawHeaders(b"Content-Type") != [
b"application/x-www-form-urlencoded"
]:
return FakeResponse.json(code=400, payload={"error": "invalid_request"})
params = parse_qs(data.decode("utf-8"))
if uri == self.token_endpoint:
# Even though this endpoint should be protected, this does not check
# for client authentication. We're not checking it for simplicity,
# and because client authentication is tested in other standalone tests.
return self.post_token_handler(params)
elif method == "GET":
if uri == self.jwks_uri:
return self.get_jwks_handler()
elif uri == self.metadata_endpoint:
return self.get_metadata_handler()
elif uri == self.userinfo_endpoint:
return self.get_userinfo_handler(access_token=access_token)
return FakeResponse(code=404, body=b"404 not found")
# Request handlers
def _get_jwks_handler(self) -> IResponse:
"""Handles requests to the JWKS URI."""
return FakeResponse.json(payload=self.get_jwks())
def _get_metadata_handler(self) -> IResponse:
"""Handles requests to the OIDC well-known document."""
return FakeResponse.json(payload=self.get_metadata())
def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse:
"""Handles requests to the userinfo endpoint."""
if access_token is None:
return FakeResponse(code=401)
user_info = self.get_userinfo(access_token)
if user_info is None:
return FakeResponse(code=401)
return FakeResponse.json(payload=user_info)
def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse:
"""Handles requests to the token endpoint."""
code = params.get("code", [])
if len(code) != 1:
return FakeResponse.json(code=400, payload={"error": "invalid_request"})
grant = self.exchange_code(code=code[0])
if grant is None:
return FakeResponse.json(code=400, payload={"error": "invalid_grant"})
return FakeResponse.json(payload=grant)