Merge pull request #9036 from matrix-org/rav/multi_idp/tests

Add tests for the IdP picker
This commit is contained in:
Richard van der Hoff 2021-01-08 14:24:41 +00:00 committed by GitHub
commit 12f79da587
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 253 additions and 38 deletions

1
changelog.d/9036.feature Normal file
View File

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

View File

@ -103,6 +103,7 @@ files =
tests/replication, tests/replication,
tests/test_utils, tests/test_utils,
tests/handlers/test_password_providers.py, tests/handlers/test_password_providers.py,
tests/rest/client/v1/test_login.py,
tests/rest/client/v2_alpha/test_auth.py, tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_stream_change_cache.py tests/util/test_stream_change_cache.py

View File

@ -319,9 +319,9 @@ class SsoRedirectServlet(RestServlet):
# register themselves with the main SSOHandler. # register themselves with the main SSOHandler.
if hs.config.cas_enabled: if hs.config.cas_enabled:
hs.get_cas_handler() hs.get_cas_handler()
elif hs.config.saml2_enabled: if hs.config.saml2_enabled:
hs.get_saml_handler() hs.get_saml_handler()
elif hs.config.oidc_enabled: if hs.config.oidc_enabled:
hs.get_oidc_handler() hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()

View File

@ -1,22 +1,67 @@
import json # -*- coding: utf-8 -*-
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time import time
import urllib.parse import urllib.parse
from html.parser import HTMLParser
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from mock import Mock from mock import Mock
try: import pymacaroons
import jwt
except ImportError: from twisted.web.resource import Resource
jwt = None
import synapse.rest.admin import synapse.rest.admin
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from synapse.rest.synapse.client.pick_idp import PickIdpResource
from tests import unittest from tests import unittest
from tests.unittest import override_config from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.unittest import override_config, skip_unless
try:
import jwt
HAS_JWT = True
except ImportError:
HAS_JWT = False
# public_base_url used in some tests
BASE_URL = "https://synapse/"
# CAS server used in some tests
CAS_SERVER = "https://fake.test"
# just enough to tell pysaml2 where to redirect to
SAML_SERVER = "https://test.saml.server/idp/sso"
TEST_SAML_METADATA = """
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
<md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>
""" % {
"SAML_SERVER": SAML_SERVER,
}
LOGIN_URL = b"/_matrix/client/r0/login" LOGIN_URL = b"/_matrix/client/r0/login"
TEST_URL = b"/_matrix/client/r0/account/whoami" TEST_URL = b"/_matrix/client/r0/account/whoami"
@ -314,6 +359,184 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
"""Tests for homeservers with multiple SSO providers enabled"""
servlets = [
login.register_servlets,
]
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
config["cas_config"] = {
"enabled": True,
"server_url": CAS_SERVER,
"service_url": "https://matrix.goodserver.com:8448",
}
config["saml2_config"] = {
"sp_config": {
"metadata": {"inline": [TEST_SAML_METADATA]},
# use the XMLSecurity backend to avoid relying on xmlsec1
"crypto_backend": "XMLSecurity",
},
}
config["oidc_config"] = TEST_OIDC_CONFIG
return config
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
return d
def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
client_redirect_url = "https://x?<abc>"
# first hit the redirect url, which should redirect to our idp picker
channel = self.make_request(
"GET",
"/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
)
self.assertEqual(channel.code, 302, channel.result)
uri = channel.headers.getRawHeaders("Location")[0]
# hitting that picker should give us some HTML
channel = self.make_request("GET", uri)
self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
class FormPageParser(HTMLParser):
def __init__(self):
super().__init__()
# the values of the hidden inputs: map from name to value
self.hiddens = {} # type: Dict[str, Optional[str]]
# the values of the radio buttons
self.radios = [] # type: List[Optional[str]]
def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
) -> None:
attr_dict = dict(attrs)
if tag == "input":
if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
self.radios.append(attr_dict["value"])
elif attr_dict["type"] == "hidden":
input_name = attr_dict["name"]
assert input_name
self.hiddens[input_name] = attr_dict["value"]
def error(_, message):
self.fail(message)
p = FormPageParser()
p.feed(channel.result["body"].decode("utf-8"))
p.close()
self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
def test_multi_sso_redirect_to_cas(self):
"""If CAS is chosen, should redirect to the CAS server"""
client_redirect_url = "https://x?<abc>"
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
shorthand=False,
)
self.assertEqual(channel.code, 302, channel.result)
cas_uri = channel.headers.getRawHeaders("Location")[0]
cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
# it should redirect us to the login page of the cas server
self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
# check that the redirectUrl is correctly encoded in the service param - ie, the
# place that CAS will redirect to
cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
service_uri = cas_uri_params["service"][0]
_, service_uri_query = service_uri.split("?", 1)
service_uri_params = urllib.parse.parse_qs(service_uri_query)
self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)
def test_multi_sso_redirect_to_saml(self):
"""If SAML is chosen, should redirect to the SAML server"""
client_redirect_url = "https://x?<abc>"
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ client_redirect_url
+ "&idp=saml",
)
self.assertEqual(channel.code, 302, channel.result)
saml_uri = channel.headers.getRawHeaders("Location")[0]
saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
# it should redirect us to the login page of the SAML server
self.assertEqual(saml_uri_path, SAML_SERVER)
# the RelayState is used to carry the client redirect url
saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
relay_state_param = saml_uri_params["RelayState"][0]
self.assertEqual(relay_state_param, client_redirect_url)
def test_multi_sso_redirect_to_oidc(self):
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
client_redirect_url = "https://x?<abc>"
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ client_redirect_url
+ "&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
# ... and should have set a cookie including the redirect url
cookies = dict(
h.split(";")[0].split("=", maxsplit=1)
for h in channel.headers.getRawHeaders("Set-Cookie")
)
oidc_session_cookie = cookies["oidc_session"]
macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
self.assertEqual(
self._get_value_from_macaroon(macaroon, "client_redirect_url"),
client_redirect_url,
)
def test_multi_sso_redirect_to_unknown(self):
"""An unknown IdP should cause a 400"""
channel = self.make_request(
"GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
self.assertEqual(channel.code, 400, channel.result)
@staticmethod
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
class CASTestCase(unittest.HomeserverTestCase): class CASTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
@ -327,7 +550,7 @@ class CASTestCase(unittest.HomeserverTestCase):
config = self.default_config() config = self.default_config()
config["cas_config"] = { config["cas_config"] = {
"enabled": True, "enabled": True,
"server_url": "https://fake.test", "server_url": CAS_SERVER,
"service_url": "https://matrix.goodserver.com:8448", "service_url": "https://matrix.goodserver.com:8448",
} }
@ -413,8 +636,7 @@ class CASTestCase(unittest.HomeserverTestCase):
} }
) )
def test_cas_redirect_whitelisted(self): def test_cas_redirect_whitelisted(self):
"""Tests that the SSO login flow serves a redirect to a whitelisted url """Tests that the SSO login flow serves a redirect to a whitelisted url"""
"""
self._test_redirect("https://legit-site.com/") self._test_redirect("https://legit-site.com/")
@override_config({"public_baseurl": "https://example.com"}) @override_config({"public_baseurl": "https://example.com"})
@ -462,10 +684,8 @@ class CASTestCase(unittest.HomeserverTestCase):
self.assertIn(b"SSO account deactivated", channel.result["body"]) self.assertIn(b"SSO account deactivated", channel.result["body"])
@skip_unless(HAS_JWT, "requires jwt")
class JWTTestCase(unittest.HomeserverTestCase): class JWTTestCase(unittest.HomeserverTestCase):
if not jwt:
skip = "requires jwt"
servlets = [ servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource, synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets, login.register_servlets,
@ -481,17 +701,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = self.jwt_algorithm self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs return self.hs
def jwt_encode(self, token: str, secret: str = jwt_secret) -> str: def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result = jwt.encode(token, secret, self.jwt_algorithm) result = jwt.encode(
payload, secret, self.jwt_algorithm
) # type: Union[str, bytes]
if isinstance(result, bytes): if isinstance(result, bytes):
return result.decode("ascii") return result.decode("ascii")
return result return result
def jwt_login(self, *args): def jwt_login(self, *args):
params = json.dumps( params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
)
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
return channel return channel
@ -623,7 +843,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
) )
def test_login_no_token(self): def test_login_no_token(self):
params = json.dumps({"type": "org.matrix.login.jwt"}) params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -633,10 +853,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens # RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key. # signed by the private key.
@skip_unless(HAS_JWT, "requires jwt")
class JWTPubKeyTestCase(unittest.HomeserverTestCase): class JWTPubKeyTestCase(unittest.HomeserverTestCase):
if not jwt:
skip = "requires jwt"
servlets = [ servlets = [
login.register_servlets, login.register_servlets,
] ]
@ -693,17 +911,15 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = "RS256" self.hs.config.jwt_algorithm = "RS256"
return self.hs return self.hs
def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str: def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result = jwt.encode(token, secret, "RS256") result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str]
if isinstance(result, bytes): if isinstance(result, bytes):
return result.decode("ascii") return result.decode("ascii")
return result return result
def jwt_login(self, *args): def jwt_login(self, *args):
params = json.dumps( params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
)
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
return channel return channel
@ -773,8 +989,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
return self.hs return self.hs
def test_login_appservice_user(self): def test_login_appservice_user(self):
"""Test that an appservice user can use /login """Test that an appservice user can use /login"""
"""
self.register_as_user(AS_USER) self.register_as_user(AS_USER)
params = { params = {
@ -788,8 +1003,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_user_bot(self): def test_login_appservice_user_bot(self):
"""Test that the appservice bot can use /login """Test that the appservice bot can use /login"""
"""
self.register_as_user(AS_USER) self.register_as_user(AS_USER)
params = { params = {
@ -803,8 +1017,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_wrong_user(self): def test_login_appservice_wrong_user(self):
"""Test that non-as users cannot login with the as token """Test that non-as users cannot login with the as token"""
"""
self.register_as_user(AS_USER) self.register_as_user(AS_USER)
params = { params = {
@ -818,8 +1031,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
def test_login_appservice_wrong_as(self): def test_login_appservice_wrong_as(self):
"""Test that as users cannot login with wrong as token """Test that as users cannot login with wrong as token"""
"""
self.register_as_user(AS_USER) self.register_as_user(AS_USER)
params = { params = {

View File

@ -444,6 +444,7 @@ class RestHelper:
# an 'oidc_config' suitable for login_via_oidc. # an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
TEST_OIDC_CONFIG = { TEST_OIDC_CONFIG = {
"enabled": True, "enabled": True,
"discover": False, "discover": False,
@ -451,7 +452,7 @@ TEST_OIDC_CONFIG = {
"client_id": "test-client-id", "client_id": "test-client-id",
"client_secret": "test-client-secret", "client_secret": "test-client-secret",
"scopes": ["profile"], "scopes": ["profile"],
"authorization_endpoint": "https://z", "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
"token_endpoint": "https://issuer.test/token", "token_endpoint": "https://issuer.test/token",
"userinfo_endpoint": "https://issuer.test/userinfo", "userinfo_endpoint": "https://issuer.test/userinfo",
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},