Fix a bug in automatic user creation with m.login.jwt. (#7585)

This commit is contained in:
Olof Johansson 2020-06-01 18:55:07 +02:00 committed by GitHub
parent 33c39ab93c
commit fe434cd3c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 162 additions and 7 deletions

1
changelog.d/7585.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug in automatic user creation during first time login with `m.login.jwt`. Regression in v1.6.0. Contributed by @olof.

View File

@ -299,7 +299,7 @@ class LoginRestServlet(RestServlet):
return result return result
async def _complete_login( async def _complete_login(
self, user_id, login_submission, callback=None, create_non_existant_users=False self, user_id, login_submission, callback=None, create_non_existent_users=False
): ):
"""Called when we've successfully authed the user and now need to """Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on actually login them in (e.g. create devices). This gets called on
@ -312,7 +312,7 @@ class LoginRestServlet(RestServlet):
user_id (str): ID of the user to register. user_id (str): ID of the user to register.
login_submission (dict): Dictionary of login information. login_submission (dict): Dictionary of login information.
callback (func|None): Callback function to run after registration. callback (func|None): Callback function to run after registration.
create_non_existant_users (bool): Whether to create the user if create_non_existent_users (bool): Whether to create the user if
they don't exist. Defaults to False. they don't exist. Defaults to False.
Returns: Returns:
@ -331,12 +331,13 @@ class LoginRestServlet(RestServlet):
update=True, update=True,
) )
if create_non_existant_users: if create_non_existent_users:
user_id = await self.auth_handler.check_user_exists(user_id) canonical_uid = await self.auth_handler.check_user_exists(user_id)
if not user_id: if not canonical_uid:
user_id = await self.registration_handler.register_user( canonical_uid = await self.registration_handler.register_user(
localpart=UserID.from_string(user_id).localpart localpart=UserID.from_string(user_id).localpart
) )
user_id = canonical_uid
device_id = login_submission.get("device_id") device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name") initial_display_name = login_submission.get("initial_device_display_name")
@ -391,7 +392,7 @@ class LoginRestServlet(RestServlet):
user_id = UserID(user, self.hs.hostname).to_string() user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login( result = await self._complete_login(
user_id, login_submission, create_non_existant_users=True user_id, login_submission, create_non_existent_users=True
) )
return result return result

View File

@ -1,8 +1,11 @@
import json import json
import time
import urllib.parse import urllib.parse
from mock import Mock from mock import Mock
import jwt
import synapse.rest.admin import synapse.rest.admin
from synapse.rest.client.v1 import login, logout from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices from synapse.rest.client.v2_alpha import devices
@ -473,3 +476,153 @@ class CASTestCase(unittest.HomeserverTestCase):
# Because the user is deactivated they are served an error template. # Because the user is deactivated they are served an error template.
self.assertEqual(channel.code, 403) self.assertEqual(channel.code, 403)
self.assertIn(b"SSO account deactivated", channel.result["body"]) self.assertIn(b"SSO account deactivated", channel.result["body"])
class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
]
jwt_secret = "secret"
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt_enabled = True
self.hs.config.jwt_secret = self.jwt_secret
self.hs.config.jwt_algorithm = "HS256"
return self.hs
def jwt_encode(self, token, secret=jwt_secret):
return jwt.encode(token, secret, "HS256").decode("ascii")
def jwt_login(self, *args):
params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel
def test_login_jwt_valid_registered(self):
self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_valid_unregistered(self):
channel = self.jwt_login({"sub": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
def test_login_jwt_expired(self):
channel = self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "JWT expired")
def test_login_jwt_not_before(self):
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
def test_login_no_sub(self):
channel = self.jwt_login({"username": "root"})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
def test_login_no_token(self):
params = json.dumps({"type": "m.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key.
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
]
# This key's pubkey is used as the jwt_secret setting of synapse. Valid
# tokens are signed by this and validated using the pubkey. It is generated
# with `openssl genrsa 512` (not a secure way to generate real keys, but
# good enough for tests!)
jwt_privatekey = "\n".join(
[
"-----BEGIN RSA PRIVATE KEY-----",
"MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB",
"492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk",
"yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/",
"kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq",
"TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN",
"ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA",
"tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=",
"-----END RSA PRIVATE KEY-----",
]
)
# Generated with `openssl rsa -in foo.key -pubout`, with the the above
# private key placed in foo.key (jwt_privatekey).
jwt_pubkey = "\n".join(
[
"-----BEGIN PUBLIC KEY-----",
"MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7",
"TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==",
"-----END PUBLIC KEY-----",
]
)
# This key is used to sign tokens that shouldn't be accepted by synapse.
# Generated just like jwt_privatekey.
bad_privatekey = "\n".join(
[
"-----BEGIN RSA PRIVATE KEY-----",
"MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv",
"gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L",
"R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY",
"uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I",
"eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb",
"iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0",
"KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m",
"-----END RSA PRIVATE KEY-----",
]
)
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt_enabled = True
self.hs.config.jwt_secret = self.jwt_pubkey
self.hs.config.jwt_algorithm = "RS256"
return self.hs
def jwt_encode(self, token, secret=jwt_privatekey):
return jwt.encode(token, secret, "RS256").decode("ascii")
def jwt_login(self, *args):
params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel
def test_login_jwt_valid(self):
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")