Make _make_callback_with_userinfo async

... so that we can test its behaviour when it raises.

Also pull it out to the top level so that I can use it from other test classes.
This commit is contained in:
Richard van der Hoff 2020-12-15 13:03:31 +00:00
parent c1883f042d
commit 8388a7fb3a

View File

@ -21,6 +21,7 @@ import pymacaroons
from synapse.handlers.oidc_handler import OidcError
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID
from tests.test_utils import FakeResponse, simple_async_mock
@ -399,7 +400,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request = self._build_callback_request(
request = _build_callback_request(
code, state, session, user_agent=user_agent, ip_address=ip_address
)
@ -607,7 +608,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request = self._build_callback_request("code", state, session)
request = _build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
@ -624,7 +625,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test_user",
"username": "test_user",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", ANY, ANY, None,
)
@ -635,7 +636,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": 1234,
"username": "test_user_2",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user_2:test", ANY, ANY, None,
)
@ -648,7 +649,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error",
@ -672,14 +673,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None,
)
auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None,
)
@ -694,7 +695,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None,
)
@ -715,7 +716,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called()
args = self.assertRenderedError("mapping_error")
self.assertTrue(
@ -730,14 +731,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test", ANY, ANY, None,
)
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
)
self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
@ -762,7 +765,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
@ -784,68 +787,80 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
def _make_callback_with_userinfo(
self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
) -> None:
self.handler._exchange_code = simple_async_mock(return_value={})
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
state = "state"
session = self.handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request = self._build_callback_request("code", state, session)
async def _make_callback_with_userinfo(
hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
) -> None:
"""Mock up an OIDC callback with the given userinfo dict
self.get_success(self.handler.handle_oidc_callback(request))
We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
def _build_callback_request(
self,
code: str,
state: str,
session: str,
user_agent: str = "Browser",
ip_address: str = "10.0.0.1",
):
"""Builds a fake SynapseRequest to mock the browser callback
Args:
hs: the HomeServer impl to send the callback to.
userinfo: the OIDC userinfo dict
client_redirect_url: the URL to redirect to on success.
"""
handler = hs.get_oidc_handler()
handler._exchange_code = simple_async_mock(return_value={})
handler._parse_id_token = simple_async_mock(return_value=userinfo)
handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
Returns a Mock object which looks like the SynapseRequest we get from a browser
after SSO (before we return to the client)
state = "state"
session = handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request = _build_callback_request("code", state, session)
Args:
code: the authorization code which would have been returned by the OIDC
provider
state: the "state" param which would have been passed around in the
query param. Should be the same as was embedded in the session in
_build_oidc_session.
session: the "session" which would have been passed around in the cookie.
user_agent: the user-agent to present
ip_address: the IP address to pretend the request came from
"""
request = Mock(
spec=[
"args",
"getCookie",
"addCookie",
"requestHeaders",
"getClientIP",
"get_user_agent",
]
)
await handler.handle_oidc_callback(request)
request.getCookie.return_value = session
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = ip_address
request.get_user_agent.return_value = user_agent
return request
def _build_callback_request(
code: str,
state: str,
session: str,
user_agent: str = "Browser",
ip_address: str = "10.0.0.1",
):
"""Builds a fake SynapseRequest to mock the browser callback
Returns a Mock object which looks like the SynapseRequest we get from a browser
after SSO (before we return to the client)
Args:
code: the authorization code which would have been returned by the OIDC
provider
state: the "state" param which would have been passed around in the
query param. Should be the same as was embedded in the session in
_build_oidc_session.
session: the "session" which would have been passed around in the cookie.
user_agent: the user-agent to present
ip_address: the IP address to pretend the request came from
"""
request = Mock(
spec=[
"args",
"getCookie",
"addCookie",
"requestHeaders",
"getClientIP",
"get_user_agent",
]
)
request.getCookie.return_value = session
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = ip_address
request.get_user_agent.return_value = user_agent
return request