Preparatory refactoring of the OidcHandlerTestCase (#8911)

* Remove references to handler._auth_handler

(and replace them with hs.get_auth_handler)

* Factor out a utility function for building Requests

* Remove mocks of `OidcHandler._map_userinfo_to_user`

This method is going away, so mocking it out is no longer a valid approach.

Instead, we mock out lower-level methods (eg _remote_id_from_userinfo), or
simply allow the regular implementation to proceed and update the expectations
accordingly.

* Remove references to `OidcHandler._map_userinfo_to_user` from tests

This method is going away, so we can no longer use it as a test point. Instead
we build mock "callback" requests which we pass into `handle_oidc_callback`,
and verify correct behaviour by mocking out `AuthHandler.complete_sso_login`.
This commit is contained in:
Richard van der Hoff 2020-12-14 11:38:50 +00:00 committed by GitHub
parent f14428b25c
commit 895e04319b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 146 additions and 141 deletions

1
changelog.d/8911.feature Normal file
View File

@ -0,0 +1 @@
Add support for allowing users to pick their own user ID during a single-sign-on login.

View File

@ -15,7 +15,7 @@
import json import json
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
from mock import Mock, patch from mock import ANY, Mock, patch
import pymacaroons import pymacaroons
@ -82,7 +82,7 @@ class TestMappingProviderFailures(TestMappingProvider):
} }
def simple_async_mock(return_value=None, raises=None): def simple_async_mock(return_value=None, raises=None) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour # AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs): async def cb(*args, **kwargs):
if raises: if raises:
@ -160,6 +160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(args[2], error_description) self.assertEqual(args[2], error_description)
# Reset the render_error mock # Reset the render_error mock
self.render_error.reset_mock() self.render_error.reset_mock()
return args
def test_config(self): def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly.""" """Basic config correctly sets up the callback URL and client auth correctly."""
@ -374,26 +375,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
"id_token": "id_token", "id_token": "id_token",
"access_token": "access_token", "access_token": "access_token",
} }
username = "bar"
userinfo = { userinfo = {
"sub": "foo", "sub": "foo",
"preferred_username": "bar", "username": username,
} }
user_id = "@foo:domain.org" expected_user_id = "@%s:%s" % (username, self.hs.hostname)
self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo) self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) auth_handler = self.hs.get_auth_handler()
self.handler._auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
request = Mock(
spec=[
"args",
"getCookie",
"addCookie",
"requestHeaders",
"getClientIP",
"get_user_agent",
]
)
code = "code" code = "code"
state = "state" state = "state"
@ -401,64 +393,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect" client_redirect_url = "http://client/redirect"
user_agent = "Browser" user_agent = "Browser"
ip_address = "10.0.0.1" ip_address = "10.0.0.1"
request.getCookie.return_value = self.handler._generate_oidc_session_token( session = self.handler._generate_oidc_session_token(
state=state, state=state,
nonce=nonce, nonce=nonce,
client_redirect_url=client_redirect_url, client_redirect_url=client_redirect_url,
ui_auth_session_id=None, ui_auth_session_id=None,
) )
request = self._build_callback_request(
request.args = {} code, state, session, user_agent=user_agent, ip_address=ip_address
request.args[b"code"] = [code.encode("utf-8")] )
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = ip_address
request.get_user_agent.return_value = user_agent
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user_id, request, client_redirect_url, {}, expected_user_id, request, client_redirect_url, {},
) )
self.handler._exchange_code.assert_called_once_with(code) self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
self.handler._map_userinfo_to_user.assert_called_once_with(
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_not_called() self.handler._fetch_userinfo.assert_not_called()
self.render_error.assert_not_called() self.render_error.assert_not_called()
# Handle mapping errors # Handle mapping errors
self.handler._map_userinfo_to_user = simple_async_mock( with patch.object(
raises=MappingException() self.handler,
) "_remote_id_from_userinfo",
self.get_success(self.handler.handle_oidc_callback(request)) new=Mock(side_effect=MappingException()),
self.assertRenderedError("mapping_error") ):
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error")
# Handle ID token errors # Handle ID token errors
self.handler._parse_id_token = simple_async_mock(raises=Exception()) self.handler._parse_id_token = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token") self.assertRenderedError("invalid_token")
self.handler._auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
self.handler._exchange_code.reset_mock() self.handler._exchange_code.reset_mock()
self.handler._parse_id_token.reset_mock() self.handler._parse_id_token.reset_mock()
self.handler._map_userinfo_to_user.reset_mock()
self.handler._fetch_userinfo.reset_mock() self.handler._fetch_userinfo.reset_mock()
# With userinfo fetching # With userinfo fetching
self.handler._scopes = [] # do not ask the "openid" scope self.handler._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user_id, request, client_redirect_url, {}, expected_user_id, request, client_redirect_url, {},
) )
self.handler._exchange_code.assert_called_once_with(code) self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called() self.handler._parse_id_token.assert_not_called()
self.handler._map_userinfo_to_user.assert_called_once_with(
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called() self.render_error.assert_not_called()
@ -609,72 +591,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
userinfo = { userinfo = {
"sub": "foo", "sub": "foo",
"username": "foo",
"phone": "1234567", "phone": "1234567",
} }
user_id = "@foo:domain.org"
self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo) self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) auth_handler = self.hs.get_auth_handler()
self.handler._auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
request = Mock(
spec=[
"args",
"getCookie",
"addCookie",
"requestHeaders",
"getClientIP",
"get_user_agent",
]
)
state = "state" state = "state"
client_redirect_url = "http://client/redirect" client_redirect_url = "http://client/redirect"
request.getCookie.return_value = self.handler._generate_oidc_session_token( session = self.handler._generate_oidc_session_token(
state=state, state=state,
nonce="nonce", nonce="nonce",
client_redirect_url=client_redirect_url, client_redirect_url=client_redirect_url,
ui_auth_session_id=None, ui_auth_session_id=None,
) )
request = self._build_callback_request("code", state, session)
request.args = {}
request.args[b"code"] = [b"code"]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = "10.0.0.1"
request.get_user_agent.return_value = "Browser"
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user_id, request, client_redirect_url, {"phone": "1234567"}, "@foo:test", request, client_redirect_url, {"phone": "1234567"},
) )
def test_map_userinfo_to_user(self): def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
userinfo = { userinfo = {
"sub": "test_user", "sub": "test_user",
"username": "test_user", "username": "test_user",
} }
# The token doesn't matter with the default user mapping provider. self._make_callback_with_userinfo(userinfo)
token = {} auth_handler.complete_sso_login.assert_called_once_with(
mxid = self.get_success( "@test_user:test", ANY, ANY, {}
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
) )
self.assertEqual(mxid, "@test_user:test") auth_handler.complete_sso_login.reset_mock()
# Some providers return an integer ID. # Some providers return an integer ID.
userinfo = { userinfo = {
"sub": 1234, "sub": 1234,
"username": "test_user_2", "username": "test_user_2",
} }
mxid = self.get_success( self._make_callback_with_userinfo(userinfo)
self.handler._map_userinfo_to_user( auth_handler.complete_sso_login.assert_called_once_with(
userinfo, token, "user-agent", "10.10.10.10" "@test_user_2:test", ANY, ANY, {}
)
) )
self.assertEqual(mxid, "@test_user_2:test") auth_handler.complete_sso_login.reset_mock()
# Test if the mxid is already taken # Test if the mxid is already taken
store = self.hs.get_datastore() store = self.hs.get_datastore()
@ -683,14 +648,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None) store.register_user(user_id=user3.to_string(), password_hash=None)
) )
userinfo = {"sub": "test3", "username": "test_user_3"} userinfo = {"sub": "test3", "username": "test_user_3"}
e = self.get_failure( self._make_callback_with_userinfo(userinfo)
self.handler._map_userinfo_to_user( auth_handler.complete_sso_login.assert_not_called()
userinfo, token, "user-agent", "10.10.10.10" self.assertRenderedError(
), "mapping_error",
MappingException, "Mapping provider does not support de-duplicating Matrix IDs",
)
self.assertEqual(
str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
) )
@override_config({"oidc_config": {"allow_existing_users": True}}) @override_config({"oidc_config": {"allow_existing_users": True}})
@ -702,26 +664,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user.to_string(), password_hash=None) store.register_user(user_id=user.to_string(), password_hash=None)
) )
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# Map a user via SSO. # Map a user via SSO.
userinfo = { userinfo = {
"sub": "test", "sub": "test",
"username": "test_user", "username": "test_user",
} }
token = {} self._make_callback_with_userinfo(userinfo)
mxid = self.get_success( auth_handler.complete_sso_login.assert_called_once_with(
self.handler._map_userinfo_to_user( user.to_string(), ANY, ANY, {},
userinfo, token, "user-agent", "10.10.10.10"
)
) )
self.assertEqual(mxid, "@test_user:test") auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid. # Subsequent calls should map to the same mxid.
mxid = self.get_success( self._make_callback_with_userinfo(userinfo)
self.handler._map_userinfo_to_user( auth_handler.complete_sso_login.assert_called_once_with(
userinfo, token, "user-agent", "10.10.10.10" user.to_string(), ANY, ANY, {},
)
) )
self.assertEqual(mxid, "@test_user:test") auth_handler.complete_sso_login.reset_mock()
# Note that a second SSO user can be mapped to the same Matrix ID. (This # Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID, # requires a unique sub, but something that maps to the same matrix ID,
@ -732,13 +694,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1", "sub": "test1",
"username": "test_user", "username": "test_user",
} }
token = {} self._make_callback_with_userinfo(userinfo)
mxid = self.get_success( auth_handler.complete_sso_login.assert_called_once_with(
self.handler._map_userinfo_to_user( user.to_string(), ANY, ANY, {},
userinfo, token, "user-agent", "10.10.10.10"
)
) )
self.assertEqual(mxid, "@test_user:test") auth_handler.complete_sso_login.reset_mock()
# Register some non-exact matching cases. # Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test") user2 = UserID.from_string("@TEST_user_2:test")
@ -755,14 +715,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2", "sub": "test2",
"username": "TEST_USER_2", "username": "TEST_USER_2",
} }
e = self.get_failure( self._make_callback_with_userinfo(userinfo)
self.handler._map_userinfo_to_user( auth_handler.complete_sso_login.assert_not_called()
userinfo, token, "user-agent", "10.10.10.10" args = self.assertRenderedError("mapping_error")
),
MappingException,
)
self.assertTrue( self.assertTrue(
str(e.value).startswith( args[2].startswith(
"Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:" "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
) )
) )
@ -773,28 +730,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None) store.register_user(user_id=user2.to_string(), password_hash=None)
) )
mxid = self.get_success( self._make_callback_with_userinfo(userinfo)
self.handler._map_userinfo_to_user( auth_handler.complete_sso_login.assert_called_once_with(
userinfo, token, "user-agent", "10.10.10.10" "@TEST_USER_2:test", ANY, ANY, {},
)
) )
self.assertEqual(mxid, "@TEST_USER_2:test")
def test_map_userinfo_to_invalid_localpart(self): def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected.""" """If the mapping provider generates an invalid localpart it should be rejected."""
userinfo = { self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
"sub": "test2", self.assertRenderedError("mapping_error", "localpart is invalid: föö")
"username": "föö",
}
token = {}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertEqual(str(e.value), "localpart is invalid: föö")
@override_config( @override_config(
{ {
@ -807,6 +751,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
def test_map_userinfo_to_user_retries(self): def test_map_userinfo_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use.""" """The mapping provider can retry generating an MXID if the MXID is already in use."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
store = self.hs.get_datastore() store = self.hs.get_datastore()
self.get_success( self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None) store.register_user(user_id="@test_user:test", password_hash=None)
@ -815,14 +762,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test", "sub": "test",
"username": "test_user", "username": "test_user",
} }
token = {} self._make_callback_with_userinfo(userinfo)
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
)
# test_user is already taken, so test_user1 gets registered instead. # test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test") auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", ANY, ANY, {},
)
auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular OIDC username. # Register all of the potential mxids for a particular OIDC username.
self.get_success( self.get_success(
@ -838,12 +784,70 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester", "sub": "tester",
"username": "tester", "username": "tester",
} }
e = self.get_failure( self._make_callback_with_userinfo(userinfo)
self.handler._map_userinfo_to_user( auth_handler.complete_sso_login.assert_not_called()
userinfo, token, "user-agent", "10.10.10.10" self.assertRenderedError(
), "mapping_error", "Unable to generate a Matrix ID from the SSO response"
MappingException,
) )
self.assertEqual(
str(e.value), "Unable to generate a Matrix ID from the SSO response" def _make_callback_with_userinfo(
self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
) -> None:
self.handler._exchange_code = simple_async_mock(return_value={})
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
state = "state"
session = self.handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
) )
request = self._build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
def _build_callback_request(
self,
code: str,
state: str,
session: str,
user_agent: str = "Browser",
ip_address: str = "10.0.0.1",
):
"""Builds a fake SynapseRequest to mock the browser callback
Returns a Mock object which looks like the SynapseRequest we get from a browser
after SSO (before we return to the client)
Args:
code: the authorization code which would have been returned by the OIDC
provider
state: the "state" param which would have been passed around in the
query param. Should be the same as was embedded in the session in
_build_oidc_session.
session: the "session" which would have been passed around in the cookie.
user_agent: the user-agent to present
ip_address: the IP address to pretend the request came from
"""
request = Mock(
spec=[
"args",
"getCookie",
"addCookie",
"requestHeaders",
"getClientIP",
"get_user_agent",
]
)
request.getCookie.return_value = session
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = ip_address
request.get_user_agent.return_value = user_agent
return request