Add config for customizing the claim used for JWT logins. (#11361)

Allows specifying a different claim (from the default "sub") to use
when calculating the localpart of the Matrix ID used during the
JWT login.
This commit is contained in:
Kostas 2021-11-22 19:01:03 +01:00 committed by GitHub
parent 3d893b8cf2
commit 1035663833
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 57 additions and 35 deletions

View file

@ -815,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
jwt_secret = "secret"
jwt_algorithm = "HS256"
base_config = {
"enabled": True,
"secret": jwt_secret,
"algorithm": jwt_algorithm,
}
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt.jwt_enabled = True
self.hs.config.jwt.jwt_secret = self.jwt_secret
self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
return self.hs
def default_config(self):
config = super().default_config()
# If jwt_config has been defined (eg via @override_config), don't replace it.
if config.get("jwt_config") is None:
config["jwt_config"] = self.base_config
return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
@ -879,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
@override_config(
{
"jwt_config": {
"jwt_enabled": True,
"secret": jwt_secret,
"algorithm": jwt_algorithm,
"issuer": "test-issuer",
}
}
)
@override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
def test_login_iss(self):
"""Test validating the issuer claim."""
# A valid issuer.
@ -919,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config(
{
"jwt_config": {
"jwt_enabled": True,
"secret": jwt_secret,
"algorithm": jwt_algorithm,
"audiences": ["test-audience"],
}
}
)
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
def test_login_aud(self):
"""Test validating the audience claim."""
# A valid audience.
@ -962,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel.json_body["error"], "JWT validation failed: Invalid audience"
)
def test_login_default_sub(self):
"""Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self):
"""Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_no_token(self):
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
@ -1024,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
]
)
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt.jwt_enabled = True
self.hs.config.jwt.jwt_secret = self.jwt_pubkey
self.hs.config.jwt.jwt_algorithm = "RS256"
return self.hs
def default_config(self):
config = super().default_config()
config["jwt_config"] = {
"enabled": True,
"secret": self.jwt_pubkey,
"algorithm": "RS256",
}
return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.