mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Record the SSO Auth Provider in the login token (#9510)
This great big stack of commits is a a whole load of hoop-jumping to make it easier to store additional values in login tokens, and then to actually store the SSO Identity Provider in the login token. (Making use of that data will follow in a subsequent PR.)
This commit is contained in:
parent
a6333b8d42
commit
7eb6e39a8f
1
changelog.d/9510.feature
Normal file
1
changelog.d/9510.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add prometheus metrics for number of users successfully registering and logging in.
|
1
mypy.ini
1
mypy.ini
@ -69,6 +69,7 @@ files =
|
|||||||
synapse/util/async_helpers.py,
|
synapse/util/async_helpers.py,
|
||||||
synapse/util/caches,
|
synapse/util/caches,
|
||||||
synapse/util/metrics.py,
|
synapse/util/metrics.py,
|
||||||
|
synapse/util/macaroons.py,
|
||||||
synapse/util/stringutils.py,
|
synapse/util/stringutils.py,
|
||||||
tests/replication,
|
tests/replication,
|
||||||
tests/test_utils,
|
tests/test_utils,
|
||||||
|
@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
|
|||||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||||
from synapse.types import StateMap, UserID
|
from synapse.types import StateMap, UserID
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -408,7 +409,7 @@ class Auth:
|
|||||||
raise _InvalidMacaroonException()
|
raise _InvalidMacaroonException()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = self.get_user_id_from_macaroon(macaroon)
|
user_id = get_value_from_macaroon(macaroon, "user_id")
|
||||||
|
|
||||||
guest = False
|
guest = False
|
||||||
for caveat in macaroon.caveats:
|
for caveat in macaroon.caveats:
|
||||||
@ -416,7 +417,12 @@ class Auth:
|
|||||||
guest = True
|
guest = True
|
||||||
|
|
||||||
self.validate_macaroon(macaroon, rights, user_id=user_id)
|
self.validate_macaroon(macaroon, rights, user_id=user_id)
|
||||||
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
|
except (
|
||||||
|
pymacaroons.exceptions.MacaroonException,
|
||||||
|
KeyError,
|
||||||
|
TypeError,
|
||||||
|
ValueError,
|
||||||
|
):
|
||||||
raise InvalidClientTokenError("Invalid macaroon passed.")
|
raise InvalidClientTokenError("Invalid macaroon passed.")
|
||||||
|
|
||||||
if rights == "access":
|
if rights == "access":
|
||||||
@ -424,27 +430,6 @@ class Auth:
|
|||||||
|
|
||||||
return user_id, guest
|
return user_id, guest
|
||||||
|
|
||||||
def get_user_id_from_macaroon(self, macaroon):
|
|
||||||
"""Retrieve the user_id given by the caveats on the macaroon.
|
|
||||||
|
|
||||||
Does *not* validate the macaroon.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
macaroon (pymacaroons.Macaroon): The macaroon to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(str) user id
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
InvalidClientCredentialsError if there is no user_id caveat in the
|
|
||||||
macaroon
|
|
||||||
"""
|
|
||||||
user_prefix = "user_id = "
|
|
||||||
for caveat in macaroon.caveats:
|
|
||||||
if caveat.caveat_id.startswith(user_prefix):
|
|
||||||
return caveat.caveat_id[len(user_prefix) :]
|
|
||||||
raise InvalidClientTokenError("No user caveat in macaroon")
|
|
||||||
|
|
||||||
def validate_macaroon(self, macaroon, type_string, user_id):
|
def validate_macaroon(self, macaroon, type_string, user_id):
|
||||||
"""
|
"""
|
||||||
validate that a Macaroon is understood by and was signed by this server.
|
validate that a Macaroon is understood by and was signed by this server.
|
||||||
@ -465,21 +450,13 @@ class Auth:
|
|||||||
v.satisfy_exact("type = " + type_string)
|
v.satisfy_exact("type = " + type_string)
|
||||||
v.satisfy_exact("user_id = %s" % user_id)
|
v.satisfy_exact("user_id = %s" % user_id)
|
||||||
v.satisfy_exact("guest = true")
|
v.satisfy_exact("guest = true")
|
||||||
v.satisfy_general(self._verify_expiry)
|
satisfy_expiry(v, self.clock.time_msec)
|
||||||
|
|
||||||
# access_tokens include a nonce for uniqueness: any value is acceptable
|
# access_tokens include a nonce for uniqueness: any value is acceptable
|
||||||
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||||
|
|
||||||
v.verify(macaroon, self._macaroon_secret_key)
|
v.verify(macaroon, self._macaroon_secret_key)
|
||||||
|
|
||||||
def _verify_expiry(self, caveat):
|
|
||||||
prefix = "time < "
|
|
||||||
if not caveat.startswith(prefix):
|
|
||||||
return False
|
|
||||||
expiry = int(caveat[len(prefix) :])
|
|
||||||
now = self.hs.get_clock().time_msec()
|
|
||||||
return now < expiry
|
|
||||||
|
|
||||||
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
|
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
|
||||||
token = self.get_access_token_from_request(request)
|
token = self.get_access_token_from_request(request)
|
||||||
service = self.store.get_app_service_by_token(token)
|
service = self.store.get_app_service_by_token(token)
|
||||||
|
@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
|
|||||||
from synapse.types import JsonDict, Requester, UserID
|
from synapse.types import JsonDict, Requester, UserID
|
||||||
from synapse.util import stringutils as stringutils
|
from synapse.util import stringutils as stringutils
|
||||||
from synapse.util.async_helpers import maybe_awaitable
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
from synapse.util.threepids import canonicalise_email
|
from synapse.util.threepids import canonicalise_email
|
||||||
|
|
||||||
@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
|
|||||||
extra_attributes = attr.ib(type=JsonDict)
|
extra_attributes = attr.ib(type=JsonDict)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True)
|
||||||
|
class LoginTokenAttributes:
|
||||||
|
"""Data we store in a short-term login token"""
|
||||||
|
|
||||||
|
user_id = attr.ib(type=str)
|
||||||
|
|
||||||
|
# the SSO Identity Provider that the user authenticated with, to get this token
|
||||||
|
auth_provider_id = attr.ib(type=str)
|
||||||
|
|
||||||
|
|
||||||
class AuthHandler(BaseHandler):
|
class AuthHandler(BaseHandler):
|
||||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||||
|
|
||||||
@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler):
|
|||||||
return None
|
return None
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
|
async def validate_short_term_login_token(
|
||||||
auth_api = self.hs.get_auth()
|
self, login_token: str
|
||||||
user_id = None
|
) -> LoginTokenAttributes:
|
||||||
try:
|
try:
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
res = self.macaroon_gen.verify_short_term_login_token(login_token)
|
||||||
user_id = auth_api.get_user_id_from_macaroon(macaroon)
|
|
||||||
auth_api.validate_macaroon(macaroon, "login", user_id)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
await self.auth.check_auth_blocking(user_id)
|
await self.auth.check_auth_blocking(res.user_id)
|
||||||
return user_id
|
return res
|
||||||
|
|
||||||
async def delete_access_token(self, access_token: str):
|
async def delete_access_token(self, access_token: str):
|
||||||
"""Invalidate a single access token
|
"""Invalidate a single access token
|
||||||
@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler):
|
|||||||
async def complete_sso_login(
|
async def complete_sso_login(
|
||||||
self,
|
self,
|
||||||
registered_user_id: str,
|
registered_user_id: str,
|
||||||
|
auth_provider_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
extra_attributes: Optional[JsonDict] = None,
|
extra_attributes: Optional[JsonDict] = None,
|
||||||
@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
registered_user_id: The registered user ID to complete SSO login for.
|
registered_user_id: The registered user ID to complete SSO login for.
|
||||||
|
auth_provider_id: The id of the SSO Identity provider that was used for
|
||||||
|
login. This will be stored in the login token for future tracking in
|
||||||
|
prometheus metrics.
|
||||||
request: The request to complete.
|
request: The request to complete.
|
||||||
client_redirect_url: The URL to which to redirect the user at the end of the
|
client_redirect_url: The URL to which to redirect the user at the end of the
|
||||||
process.
|
process.
|
||||||
@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
self._complete_sso_login(
|
self._complete_sso_login(
|
||||||
registered_user_id,
|
registered_user_id,
|
||||||
|
auth_provider_id,
|
||||||
request,
|
request,
|
||||||
client_redirect_url,
|
client_redirect_url,
|
||||||
extra_attributes,
|
extra_attributes,
|
||||||
@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler):
|
|||||||
def _complete_sso_login(
|
def _complete_sso_login(
|
||||||
self,
|
self,
|
||||||
registered_user_id: str,
|
registered_user_id: str,
|
||||||
|
auth_provider_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
extra_attributes: Optional[JsonDict] = None,
|
extra_attributes: Optional[JsonDict] = None,
|
||||||
@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
# Create a login token
|
# Create a login token
|
||||||
login_token = self.macaroon_gen.generate_short_term_login_token(
|
login_token = self.macaroon_gen.generate_short_term_login_token(
|
||||||
registered_user_id
|
registered_user_id, auth_provider_id=auth_provider_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Append the login token to the original redirect URL (i.e. with its query
|
# Append the login token to the original redirect URL (i.e. with its query
|
||||||
@ -1569,15 +1584,48 @@ class MacaroonGenerator:
|
|||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def generate_short_term_login_token(
|
def generate_short_term_login_token(
|
||||||
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
|
self,
|
||||||
|
user_id: str,
|
||||||
|
auth_provider_id: str,
|
||||||
|
duration_in_ms: int = (2 * 60 * 1000),
|
||||||
) -> str:
|
) -> str:
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = login")
|
macaroon.add_first_party_caveat("type = login")
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
expiry = now + duration_in_ms
|
expiry = now + duration_in_ms
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
|
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
|
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
|
||||||
|
"""Verify a short-term-login macaroon
|
||||||
|
|
||||||
|
Checks that the given token is a valid, unexpired short-term-login token
|
||||||
|
minted by this server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: the login token to verify
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the user_id that this token is valid for
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MacaroonVerificationFailedException if the verification failed
|
||||||
|
"""
|
||||||
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
user_id = get_value_from_macaroon(macaroon, "user_id")
|
||||||
|
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
|
||||||
|
|
||||||
|
v = pymacaroons.Verifier()
|
||||||
|
v.satisfy_exact("gen = 1")
|
||||||
|
v.satisfy_exact("type = login")
|
||||||
|
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
||||||
|
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
|
||||||
|
satisfy_expiry(v, self.hs.get_clock().time_msec)
|
||||||
|
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
|
||||||
|
|
||||||
|
return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
|
||||||
|
|
||||||
def generate_delete_pusher_token(self, user_id: str) -> str:
|
def generate_delete_pusher_token(self, user_id: str) -> str:
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = delete_pusher")
|
macaroon.add_first_party_caveat("type = delete_pusher")
|
||||||
|
@ -42,6 +42,7 @@ from synapse.logging.context import make_deferred_yieldable
|
|||||||
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||||
|
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -211,7 +212,7 @@ class OidcHandler:
|
|||||||
session_data = self._token_generator.verify_oidc_session_token(
|
session_data = self._token_generator.verify_oidc_session_token(
|
||||||
session, state
|
session, state
|
||||||
)
|
)
|
||||||
except (MacaroonDeserializationException, ValueError) as e:
|
except (MacaroonDeserializationException, KeyError) as e:
|
||||||
logger.exception("Invalid session for OIDC callback")
|
logger.exception("Invalid session for OIDC callback")
|
||||||
self._sso_handler.render_error(request, "invalid_session", str(e))
|
self._sso_handler.render_error(request, "invalid_session", str(e))
|
||||||
return
|
return
|
||||||
@ -745,7 +746,7 @@ class OidcProvider:
|
|||||||
idp_id=self.idp_id,
|
idp_id=self.idp_id,
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
client_redirect_url=client_redirect_url.decode(),
|
client_redirect_url=client_redirect_url.decode(),
|
||||||
ui_auth_session_id=ui_auth_session_id,
|
ui_auth_session_id=ui_auth_session_id or "",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1020,10 +1021,9 @@ class OidcSessionTokenGenerator:
|
|||||||
macaroon.add_first_party_caveat(
|
macaroon.add_first_party_caveat(
|
||||||
"client_redirect_url = %s" % (session_data.client_redirect_url,)
|
"client_redirect_url = %s" % (session_data.client_redirect_url,)
|
||||||
)
|
)
|
||||||
if session_data.ui_auth_session_id:
|
macaroon.add_first_party_caveat(
|
||||||
macaroon.add_first_party_caveat(
|
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
|
||||||
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
|
)
|
||||||
)
|
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
expiry = now + duration_in_ms
|
expiry = now + duration_in_ms
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
@ -1046,7 +1046,7 @@ class OidcSessionTokenGenerator:
|
|||||||
The data extracted from the session cookie
|
The data extracted from the session cookie
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError if an expected caveat is missing from the macaroon.
|
KeyError if an expected caveat is missing from the macaroon.
|
||||||
"""
|
"""
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(session)
|
macaroon = pymacaroons.Macaroon.deserialize(session)
|
||||||
|
|
||||||
@ -1057,26 +1057,16 @@ class OidcSessionTokenGenerator:
|
|||||||
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||||
v.satisfy_general(lambda c: c.startswith("idp_id = "))
|
v.satisfy_general(lambda c: c.startswith("idp_id = "))
|
||||||
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
|
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
|
||||||
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
|
|
||||||
# to always satisfy this.
|
|
||||||
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
|
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
|
||||||
v.satisfy_general(self._verify_expiry)
|
satisfy_expiry(v, self._clock.time_msec)
|
||||||
|
|
||||||
v.verify(macaroon, self._macaroon_secret_key)
|
v.verify(macaroon, self._macaroon_secret_key)
|
||||||
|
|
||||||
# Extract the session data from the token.
|
# Extract the session data from the token.
|
||||||
nonce = self._get_value_from_macaroon(macaroon, "nonce")
|
nonce = get_value_from_macaroon(macaroon, "nonce")
|
||||||
idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
|
idp_id = get_value_from_macaroon(macaroon, "idp_id")
|
||||||
client_redirect_url = self._get_value_from_macaroon(
|
client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
|
||||||
macaroon, "client_redirect_url"
|
ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
|
||||||
)
|
|
||||||
try:
|
|
||||||
ui_auth_session_id = self._get_value_from_macaroon(
|
|
||||||
macaroon, "ui_auth_session_id"
|
|
||||||
) # type: Optional[str]
|
|
||||||
except ValueError:
|
|
||||||
ui_auth_session_id = None
|
|
||||||
|
|
||||||
return OidcSessionData(
|
return OidcSessionData(
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
idp_id=idp_id,
|
idp_id=idp_id,
|
||||||
@ -1084,33 +1074,6 @@ class OidcSessionTokenGenerator:
|
|||||||
ui_auth_session_id=ui_auth_session_id,
|
ui_auth_session_id=ui_auth_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
|
|
||||||
"""Extracts a caveat value from a macaroon token.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
macaroon: the token
|
|
||||||
key: the key of the caveat to extract
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The extracted value
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if the caveat was not in the macaroon
|
|
||||||
"""
|
|
||||||
prefix = key + " = "
|
|
||||||
for caveat in macaroon.caveats:
|
|
||||||
if caveat.caveat_id.startswith(prefix):
|
|
||||||
return caveat.caveat_id[len(prefix) :]
|
|
||||||
raise ValueError("No %s caveat in macaroon" % (key,))
|
|
||||||
|
|
||||||
def _verify_expiry(self, caveat: str) -> bool:
|
|
||||||
prefix = "time < "
|
|
||||||
if not caveat.startswith(prefix):
|
|
||||||
return False
|
|
||||||
expiry = int(caveat[len(prefix) :])
|
|
||||||
now = self._clock.time_msec()
|
|
||||||
return now < expiry
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(frozen=True, slots=True)
|
@attr.s(frozen=True, slots=True)
|
||||||
class OidcSessionData:
|
class OidcSessionData:
|
||||||
@ -1125,8 +1088,8 @@ class OidcSessionData:
|
|||||||
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
|
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
|
||||||
client_redirect_url = attr.ib(type=str)
|
client_redirect_url = attr.ib(type=str)
|
||||||
|
|
||||||
# The session ID of the ongoing UI Auth (None if this is a login)
|
# The session ID of the ongoing UI Auth ("" if this is a login)
|
||||||
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
|
ui_auth_session_id = attr.ib(type=str)
|
||||||
|
|
||||||
|
|
||||||
UserAttributeDict = TypedDict(
|
UserAttributeDict = TypedDict(
|
||||||
|
@ -456,6 +456,7 @@ class SsoHandler:
|
|||||||
|
|
||||||
await self._auth_handler.complete_sso_login(
|
await self._auth_handler.complete_sso_login(
|
||||||
user_id,
|
user_id,
|
||||||
|
auth_provider_id,
|
||||||
request,
|
request,
|
||||||
client_redirect_url,
|
client_redirect_url,
|
||||||
extra_login_attributes,
|
extra_login_attributes,
|
||||||
@ -886,6 +887,7 @@ class SsoHandler:
|
|||||||
|
|
||||||
await self._auth_handler.complete_sso_login(
|
await self._auth_handler.complete_sso_login(
|
||||||
user_id,
|
user_id,
|
||||||
|
session.auth_provider_id,
|
||||||
request,
|
request,
|
||||||
session.client_redirect_url,
|
session.client_redirect_url,
|
||||||
session.extra_login_attributes,
|
session.extra_login_attributes,
|
||||||
|
@ -203,11 +203,26 @@ class ModuleApi:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def generate_short_term_login_token(
|
def generate_short_term_login_token(
|
||||||
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
|
self,
|
||||||
|
user_id: str,
|
||||||
|
duration_in_ms: int = (2 * 60 * 1000),
|
||||||
|
auth_provider_id: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate a login token suitable for m.login.token authentication"""
|
"""Generate a login token suitable for m.login.token authentication
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: gives the ID of the user that the token is for
|
||||||
|
|
||||||
|
duration_in_ms: the time that the token will be valid for
|
||||||
|
|
||||||
|
auth_provider_id: the ID of the SSO IdP that the user used to authenticate
|
||||||
|
to get this token, if any. This is encoded in the token so that
|
||||||
|
/login can report stats on number of successful logins by IdP.
|
||||||
|
"""
|
||||||
return self._hs.get_macaroon_generator().generate_short_term_login_token(
|
return self._hs.get_macaroon_generator().generate_short_term_login_token(
|
||||||
user_id, duration_in_ms
|
user_id,
|
||||||
|
auth_provider_id,
|
||||||
|
duration_in_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -276,6 +291,7 @@ class ModuleApi:
|
|||||||
"""
|
"""
|
||||||
self._auth_handler._complete_sso_login(
|
self._auth_handler._complete_sso_login(
|
||||||
registered_user_id,
|
registered_user_id,
|
||||||
|
"<unknown>",
|
||||||
request,
|
request,
|
||||||
client_redirect_url,
|
client_redirect_url,
|
||||||
)
|
)
|
||||||
@ -286,6 +302,7 @@ class ModuleApi:
|
|||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
new_user: bool = False,
|
new_user: bool = False,
|
||||||
|
auth_provider_id: str = "<unknown>",
|
||||||
):
|
):
|
||||||
"""Complete a SSO login by redirecting the user to a page to confirm whether they
|
"""Complete a SSO login by redirecting the user to a page to confirm whether they
|
||||||
want their access token sent to `client_redirect_url`, or redirect them to that
|
want their access token sent to `client_redirect_url`, or redirect them to that
|
||||||
@ -299,9 +316,15 @@ class ModuleApi:
|
|||||||
redirect them directly if whitelisted).
|
redirect them directly if whitelisted).
|
||||||
new_user: set to true to use wording for the consent appropriate to a user
|
new_user: set to true to use wording for the consent appropriate to a user
|
||||||
who has just registered.
|
who has just registered.
|
||||||
|
auth_provider_id: the ID of the SSO IdP which was used to log in. This
|
||||||
|
is used to track counts of sucessful logins by IdP.
|
||||||
"""
|
"""
|
||||||
await self._auth_handler.complete_sso_login(
|
await self._auth_handler.complete_sso_login(
|
||||||
registered_user_id, request, client_redirect_url, new_user=new_user
|
registered_user_id,
|
||||||
|
auth_provider_id,
|
||||||
|
request,
|
||||||
|
client_redirect_url,
|
||||||
|
new_user=new_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -283,12 +283,10 @@ class LoginRestServlet(RestServlet):
|
|||||||
"""
|
"""
|
||||||
token = login_submission["token"]
|
token = login_submission["token"]
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
|
res = await auth_handler.validate_short_term_login_token(token)
|
||||||
token
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self._complete_login(
|
return await self._complete_login(
|
||||||
user_id, login_submission, self.auth_handler._sso_login_callback
|
res.user_id, login_submission, self.auth_handler._sso_login_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||||
|
89
synapse/util/macaroons.py
Normal file
89
synapse/util/macaroons.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 Quentin Gliech
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Utilities for manipulating macaroons"""
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import pymacaroons
|
||||||
|
from pymacaroons.exceptions import MacaroonVerificationFailedException
|
||||||
|
|
||||||
|
|
||||||
|
def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
|
||||||
|
"""Extracts a caveat value from a macaroon token.
|
||||||
|
|
||||||
|
Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
|
||||||
|
and returns the extracted value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
macaroon: the token
|
||||||
|
key: the key of the caveat to extract
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The extracted value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MacaroonVerificationFailedException: if there are conflicting values for the
|
||||||
|
caveat in the macaroon, or if the caveat was not found in the macaroon.
|
||||||
|
"""
|
||||||
|
prefix = key + " = "
|
||||||
|
result = None # type: Optional[str]
|
||||||
|
for caveat in macaroon.caveats:
|
||||||
|
if not caveat.caveat_id.startswith(prefix):
|
||||||
|
continue
|
||||||
|
|
||||||
|
val = caveat.caveat_id[len(prefix) :]
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
# first time we found this caveat: record the value
|
||||||
|
result = val
|
||||||
|
elif val != result:
|
||||||
|
# on subsequent occurrences, raise if the value is different.
|
||||||
|
raise MacaroonVerificationFailedException(
|
||||||
|
"Conflicting values for caveat " + key
|
||||||
|
)
|
||||||
|
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# If the caveat is not there, we raise a MacaroonVerificationFailedException.
|
||||||
|
# Note that it is insecure to generate a macaroon without all the caveats you
|
||||||
|
# might need (because there is nothing stopping people from adding extra caveats),
|
||||||
|
# so if the caveat isn't there, something odd must be going on.
|
||||||
|
raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
|
||||||
|
|
||||||
|
|
||||||
|
def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
|
||||||
|
"""Make a macaroon verifier which accepts 'time' caveats
|
||||||
|
|
||||||
|
Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
|
||||||
|
the given macaroon verifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
v: the macaroon verifier
|
||||||
|
get_time_ms: a callable which will return the timestamp after which the caveat
|
||||||
|
should be considered expired. Normally the current time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def verify_expiry_caveat(caveat: str):
|
||||||
|
time_msec = get_time_ms()
|
||||||
|
prefix = "time < "
|
||||||
|
if not caveat.startswith(prefix):
|
||||||
|
return False
|
||||||
|
expiry = int(caveat[len(prefix) :])
|
||||||
|
return time_msec < expiry
|
||||||
|
|
||||||
|
v.satisfy_general(verify_expiry_caveat)
|
@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||||
|
|
||||||
def test_short_term_login_token_gives_user_id(self):
|
def test_short_term_login_token_gives_user_id(self):
|
||||||
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
user_id = self.get_success(
|
"a_user", "", 5000
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
|
||||||
)
|
)
|
||||||
self.assertEqual("a_user", user_id)
|
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
|
||||||
|
self.assertEqual("a_user", res.user_id)
|
||||||
|
self.assertEqual("", res.auth_provider_id)
|
||||||
|
|
||||||
# when we advance the clock, the token should be rejected
|
# when we advance the clock, the token should be rejected
|
||||||
self.reactor.advance(6)
|
self.reactor.advance(6)
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
|
self.auth_handler.validate_short_term_login_token(token),
|
||||||
AuthError,
|
AuthError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_short_term_login_token_gives_auth_provider(self):
|
||||||
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
|
"a_user", auth_provider_id="my_idp"
|
||||||
|
)
|
||||||
|
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
|
||||||
|
self.assertEqual("a_user", res.user_id)
|
||||||
|
self.assertEqual("my_idp", res.auth_provider_id)
|
||||||
|
|
||||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||||
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
|
"a_user", "", 5000
|
||||||
|
)
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
||||||
user_id = self.get_success(
|
res = self.get_success(
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
self.auth_handler.validate_short_term_login_token(macaroon.serialize())
|
||||||
macaroon.serialize()
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self.assertEqual("a_user", user_id)
|
self.assertEqual("a_user", res.user_id)
|
||||||
|
|
||||||
# add another "user_id" caveat, which might allow us to override the
|
# add another "user_id" caveat, which might allow us to override the
|
||||||
# user_id.
|
# user_id.
|
||||||
macaroon.add_first_party_caveat("user_id = b_user")
|
macaroon.add_first_party_caveat("user_id = b_user")
|
||||||
|
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
|
||||||
macaroon.serialize()
|
|
||||||
),
|
|
||||||
AuthError,
|
AuthError,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
self.auth_handler.validate_short_term_login_token(
|
||||||
self._get_macaroon().serialize()
|
self._get_macaroon().serialize()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
return_value=make_awaitable(self.large_number_of_users)
|
return_value=make_awaitable(self.large_number_of_users)
|
||||||
)
|
)
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
self.auth_handler.validate_short_term_login_token(
|
||||||
self._get_macaroon().serialize()
|
self._get_macaroon().serialize()
|
||||||
),
|
),
|
||||||
ResourceLimitError,
|
ResourceLimitError,
|
||||||
@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
ResourceLimitError,
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
self.auth_handler.validate_short_term_login_token(
|
||||||
self._get_macaroon().serialize()
|
self._get_macaroon().serialize()
|
||||||
),
|
),
|
||||||
ResourceLimitError,
|
ResourceLimitError,
|
||||||
@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
self.auth_handler.validate_short_term_login_token(
|
||||||
self._get_macaroon().serialize()
|
self._get_macaroon().serialize()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
return_value=make_awaitable(self.small_number_of_users)
|
return_value=make_awaitable(self.small_number_of_users)
|
||||||
)
|
)
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
self.auth_handler.validate_short_term_login_token(
|
||||||
self._get_macaroon().serialize()
|
self._get_macaroon().serialize()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_macaroon(self):
|
def _get_macaroon(self):
|
||||||
token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
|
"user_a", "", 5000
|
||||||
|
)
|
||||||
return pymacaroons.Macaroon.deserialize(token)
|
return pymacaroons.Macaroon.deserialize(token)
|
||||||
|
@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", request, "redirect_uri", None, new_user=True
|
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_cas_user_to_existing_user(self):
|
def test_map_cas_user_to_existing_user(self):
|
||||||
@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", request, "redirect_uri", None, new_user=False
|
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Subsequent calls should map to the same mxid.
|
# Subsequent calls should map to the same mxid.
|
||||||
@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
|||||||
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
|
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", request, "redirect_uri", None, new_user=False
|
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_cas_user_to_invalid_localpart(self):
|
def test_map_cas_user_to_invalid_localpart(self):
|
||||||
@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
|
"@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", request, "redirect_uri", None, new_user=True
|
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
from typing import Optional
|
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
from mock import ANY, Mock, patch
|
from mock import ANY, Mock, patch
|
||||||
@ -23,6 +22,7 @@ import pymacaroons
|
|||||||
from synapse.handlers.sso import MappingException
|
from synapse.handlers.sso import MappingException
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
from synapse.util.macaroons import get_value_from_macaroon
|
||||||
|
|
||||||
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
@ -360,15 +360,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
self.assertEqual(name, b"oidc_session")
|
self.assertEqual(name, b"oidc_session")
|
||||||
|
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(cookie)
|
macaroon = pymacaroons.Macaroon.deserialize(cookie)
|
||||||
state = self.handler._token_generator._get_value_from_macaroon(
|
state = get_value_from_macaroon(macaroon, "state")
|
||||||
macaroon, "state"
|
nonce = get_value_from_macaroon(macaroon, "nonce")
|
||||||
)
|
redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
|
||||||
nonce = self.handler._token_generator._get_value_from_macaroon(
|
|
||||||
macaroon, "nonce"
|
|
||||||
)
|
|
||||||
redirect = self.handler._token_generator._get_value_from_macaroon(
|
|
||||||
macaroon, "client_redirect_url"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(params["state"], [state])
|
self.assertEqual(params["state"], [state])
|
||||||
self.assertEqual(params["nonce"], [nonce])
|
self.assertEqual(params["nonce"], [nonce])
|
||||||
@ -434,7 +428,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
expected_user_id, request, client_redirect_url, None, new_user=True
|
expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
|
||||||
)
|
)
|
||||||
self.provider._exchange_code.assert_called_once_with(code)
|
self.provider._exchange_code.assert_called_once_with(code)
|
||||||
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||||
@ -465,7 +459,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
expected_user_id, request, client_redirect_url, None, new_user=False
|
expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
|
||||||
)
|
)
|
||||||
self.provider._exchange_code.assert_called_once_with(code)
|
self.provider._exchange_code.assert_called_once_with(code)
|
||||||
self.provider._parse_id_token.assert_not_called()
|
self.provider._parse_id_token.assert_not_called()
|
||||||
@ -651,6 +645,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@foo:test",
|
"@foo:test",
|
||||||
|
"oidc",
|
||||||
request,
|
request,
|
||||||
client_redirect_url,
|
client_redirect_url,
|
||||||
{"phone": "1234567"},
|
{"phone": "1234567"},
|
||||||
@ -668,7 +663,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", ANY, ANY, None, new_user=True
|
"@test_user:test", "oidc", ANY, ANY, None, new_user=True
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
@ -679,7 +674,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user_2:test", ANY, ANY, None, new_user=True
|
"@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
@ -716,14 +711,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), ANY, ANY, None, new_user=False
|
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
# Subsequent calls should map to the same mxid.
|
# Subsequent calls should map to the same mxid.
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), ANY, ANY, None, new_user=False
|
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
@ -738,7 +733,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), ANY, ANY, None, new_user=False
|
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
@ -774,7 +769,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@TEST_USER_2:test", ANY, ANY, None, new_user=False
|
"@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_userinfo_to_invalid_localpart(self):
|
def test_map_userinfo_to_invalid_localpart(self):
|
||||||
@ -810,7 +805,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# test_user is already taken, so test_user1 gets registered instead.
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user1:test", ANY, ANY, None, new_user=True
|
"@test_user1:test", "oidc", ANY, ANY, None, new_user=True
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
@ -866,7 +861,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||||||
state: str,
|
state: str,
|
||||||
nonce: str,
|
nonce: str,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
ui_auth_session_id: Optional[str] = None,
|
ui_auth_session_id: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
from synapse.handlers.oidc_handler import OidcSessionData
|
from synapse.handlers.oidc_handler import OidcSessionData
|
||||||
|
|
||||||
@ -909,6 +904,7 @@ async def _make_callback_with_userinfo(
|
|||||||
idp_id="oidc",
|
idp_id="oidc",
|
||||||
nonce="nonce",
|
nonce="nonce",
|
||||||
client_redirect_url=client_redirect_url,
|
client_redirect_url=client_redirect_url,
|
||||||
|
ui_auth_session_id="",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
request = _build_callback_request("code", state, session)
|
request = _build_callback_request("code", state, session)
|
||||||
|
@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", request, "redirect_uri", None, new_user=True
|
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||||
@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", request, "", None, new_user=False
|
"@test_user:test", "saml", request, "", None, new_user=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Subsequent calls should map to the same mxid.
|
# Subsequent calls should map to the same mxid.
|
||||||
@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
|||||||
self.handler._handle_authn_response(request, saml_response, "")
|
self.handler._handle_authn_response(request, saml_response, "")
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", request, "", None, new_user=False
|
"@test_user:test", "saml", request, "", None, new_user=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_saml_response_to_invalid_localpart(self):
|
def test_map_saml_response_to_invalid_localpart(self):
|
||||||
@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# test_user is already taken, so test_user1 gets registered instead.
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user1:test", request, "", None, new_user=True
|
"@test_user1:test", "saml", request, "", None, new_user=True
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", request, "redirect_uri", None, new_user=True
|
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user