Improve error checking for OIDC/SAML mapping providers (#8774)

Checks that the localpart returned by mapping providers for SAML and
OIDC are valid before registering new users.

Extends the OIDC tests for existing users and invalid data.
This commit is contained in:
Patrick Cloke 2020-11-19 14:25:17 -05:00 committed by GitHub
parent 53a6f5ddf0
commit 79bfe966e0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 137 additions and 29 deletions

View file

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from urllib.parse import parse_qs, urlparse
@ -24,12 +23,8 @@ import pymacaroons
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from synapse.handlers.oidc_handler import (
MappingException,
OidcError,
OidcHandler,
OidcMappingProvider,
)
from synapse.handlers.oidc_handler import OidcError, OidcHandler, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
from tests.unittest import HomeserverTestCase, override_config
@ -132,14 +127,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
config = self.default_config()
config["public_baseurl"] = BASE_URL
oidc_config = {}
oidc_config["enabled"] = True
oidc_config["client_id"] = CLIENT_ID
oidc_config["client_secret"] = CLIENT_SECRET
oidc_config["issuer"] = ISSUER
oidc_config["scopes"] = SCOPES
oidc_config["user_mapping_provider"] = {
"module": __name__ + ".TestMappingProvider",
oidc_config = {
"enabled": True,
"client_id": CLIENT_ID,
"client_secret": CLIENT_SECRET,
"issuer": ISSUER,
"scopes": SCOPES,
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# Update this config with what's in the default config so that
@ -705,13 +699,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_map_userinfo_to_existing_user(self):
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastore()
user4 = UserID.from_string("@test_user_4:test")
user = UserID.from_string("@test_user:test")
self.get_success(
store.register_user(user_id=user4.to_string(), password_hash=None)
store.register_user(user_id=user.to_string(), password_hash=None)
)
userinfo = {
"sub": "test4",
"username": "test_user_4",
"sub": "test",
"username": "test_user",
}
token = {}
mxid = self.get_success(
@ -719,4 +713,59 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user_4:test")
self.assertEqual(mxid, "@test_user:test")
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
self.get_success(
store.register_user(user_id=user2.to_string(), password_hash=None)
)
user2_caps = UserID.from_string("@test_USER_2:test")
self.get_success(
store.register_user(user_id=user2_caps.to_string(), password_hash=None)
)
# Attempting to login without matching a name exactly is an error.
userinfo = {
"sub": "test2",
"username": "TEST_USER_2",
}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertTrue(
str(e.value).startswith(
"Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
)
)
# Logging in when matching a name exactly should work.
user2 = UserID.from_string("@TEST_USER_2:test")
self.get_success(
store.register_user(user_id=user2.to_string(), password_hash=None)
)
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@TEST_USER_2:test")
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
userinfo = {
"sub": "test2",
"username": "föö",
}
token = {}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertEqual(str(e.value), "localpart is invalid: föö")