mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-10 10:30:03 -04:00
Implement MSC3231: Token authenticated registration (#10142)
Signed-off-by: Callum Brown <callum@calcuode.com> This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231).
This commit is contained in:
parent
ecd823d766
commit
947dbbdfd1
21 changed files with 2389 additions and 1 deletions
|
@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
|
|||
from synapse.api.errors import Codes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.rest.client import account, account_validity, login, logout, register, sync
|
||||
from synapse.storage._base import db_to_json
|
||||
|
||||
from tests import unittest
|
||||
from tests.unittest import override_config
|
||||
|
@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_requires_token(self):
|
||||
username = "kermit"
|
||||
device_id = "frogfone"
|
||||
token = "abcd"
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": None,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
params = {
|
||||
"username": username,
|
||||
"password": "monkey",
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
# Request without auth to get flows and session
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
flows = channel.json_body["flows"]
|
||||
# Synapse adds a dummy stage to differentiate flows where otherwise one
|
||||
# flow would be a subset of another flow.
|
||||
self.assertCountEqual(
|
||||
[[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
|
||||
(f["stages"] for f in flows),
|
||||
)
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Do the registration token stage and check it has completed
|
||||
params["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session,
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
completed = channel.json_body["completed"]
|
||||
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
|
||||
|
||||
# Do the m.login.dummy stage and check registration was successful
|
||||
params["auth"] = {
|
||||
"type": LoginType.DUMMY,
|
||||
"session": session,
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
channel = self.make_request(b"POST", self.url, request_data)
|
||||
det_data = {
|
||||
"user_id": f"@{username}:{self.hs.hostname}",
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||
|
||||
# Check the `completed` counter has been incremented and pending is 0
|
||||
res = self.get_success(
|
||||
store.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["pending", "completed"],
|
||||
)
|
||||
)
|
||||
self.assertEquals(res["completed"], 1)
|
||||
self.assertEquals(res["pending"], 0)
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_invalid(self):
|
||||
params = {
|
||||
"username": "kermit",
|
||||
"password": "monkey",
|
||||
}
|
||||
# Request without auth to get session
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Test with token param missing (invalid)
|
||||
params["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"session": session,
|
||||
}
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
# Test with non-string (invalid)
|
||||
params["auth"]["token"] = 1234
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
# Test with unknown token (invalid)
|
||||
params["auth"]["token"] = "1234"
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_limit_uses(self):
|
||||
token = "abcd"
|
||||
store = self.hs.get_datastore()
|
||||
# Create token that can be used once
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": 1,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
params1 = {"username": "bert", "password": "monkey"}
|
||||
params2 = {"username": "ernie", "password": "monkey"}
|
||||
# Do 2 requests without auth to get two session IDs
|
||||
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
session1 = channel1.json_body["session"]
|
||||
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
session2 = channel2.json_body["session"]
|
||||
|
||||
# Use token with session1 and check `pending` is 1
|
||||
params1["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session1,
|
||||
}
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
# Repeat request to make sure pending isn't increased again
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
pending = self.get_success(
|
||||
store.db_pool.simple_select_one_onecol(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcol="pending",
|
||||
)
|
||||
)
|
||||
self.assertEquals(pending, 1)
|
||||
|
||||
# Check auth fails when using token with session2
|
||||
params2["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session2,
|
||||
}
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
# Complete registration with session1
|
||||
params1["auth"]["type"] = LoginType.DUMMY
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
# Check pending=0 and completed=1
|
||||
res = self.get_success(
|
||||
store.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["pending", "completed"],
|
||||
)
|
||||
)
|
||||
self.assertEquals(res["pending"], 0)
|
||||
self.assertEquals(res["completed"], 1)
|
||||
|
||||
# Check auth still fails when using token with session2
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_expiry(self):
|
||||
token = "abcd"
|
||||
now = self.hs.get_clock().time_msec()
|
||||
store = self.hs.get_datastore()
|
||||
# Create token that expired yesterday
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": None,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": now - 24 * 60 * 60 * 1000,
|
||||
},
|
||||
)
|
||||
)
|
||||
params = {"username": "kermit", "password": "monkey"}
|
||||
# Request without auth to get session
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Check authentication fails with expired token
|
||||
params["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session,
|
||||
}
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
# Update token so it expires tomorrow
|
||||
self.get_success(
|
||||
store.db_pool.simple_update_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
|
||||
)
|
||||
)
|
||||
|
||||
# Check authentication succeeds
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
completed = channel.json_body["completed"]
|
||||
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_session_expiry(self):
|
||||
"""Test `pending` is decremented when an uncompleted session expires."""
|
||||
token = "abcd"
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": None,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Do 2 requests without auth to get two session IDs
|
||||
params1 = {"username": "bert", "password": "monkey"}
|
||||
params2 = {"username": "ernie", "password": "monkey"}
|
||||
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
session1 = channel1.json_body["session"]
|
||||
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
session2 = channel2.json_body["session"]
|
||||
|
||||
# Use token with both sessions
|
||||
params1["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session1,
|
||||
}
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
|
||||
params2["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session2,
|
||||
}
|
||||
self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
|
||||
# Complete registration with session1
|
||||
params1["auth"]["type"] = LoginType.DUMMY
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
|
||||
# Check `result` of registration token stage for session1 is `True`
|
||||
result1 = self.get_success(
|
||||
store.db_pool.simple_select_one_onecol(
|
||||
"ui_auth_sessions_credentials",
|
||||
keyvalues={
|
||||
"session_id": session1,
|
||||
"stage_type": LoginType.REGISTRATION_TOKEN,
|
||||
},
|
||||
retcol="result",
|
||||
)
|
||||
)
|
||||
self.assertTrue(db_to_json(result1))
|
||||
|
||||
# Check `result` for session2 is the token used
|
||||
result2 = self.get_success(
|
||||
store.db_pool.simple_select_one_onecol(
|
||||
"ui_auth_sessions_credentials",
|
||||
keyvalues={
|
||||
"session_id": session2,
|
||||
"stage_type": LoginType.REGISTRATION_TOKEN,
|
||||
},
|
||||
retcol="result",
|
||||
)
|
||||
)
|
||||
self.assertEquals(db_to_json(result2), token)
|
||||
|
||||
# Delete both sessions (mimics expiry)
|
||||
self.get_success(
|
||||
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
|
||||
)
|
||||
|
||||
# Check pending is now 0
|
||||
pending = self.get_success(
|
||||
store.db_pool.simple_select_one_onecol(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcol="pending",
|
||||
)
|
||||
)
|
||||
self.assertEquals(pending, 0)
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_session_expiry_deleted_token(self):
|
||||
"""Test session expiry doesn't break when the token is deleted.
|
||||
|
||||
1. Start but don't complete UIA with a registration token
|
||||
2. Delete the token from the database
|
||||
3. Expire the session
|
||||
"""
|
||||
token = "abcd"
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": None,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Do request without auth to get a session ID
|
||||
params = {"username": "kermit", "password": "monkey"}
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Use token
|
||||
params["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session,
|
||||
}
|
||||
self.make_request(b"POST", self.url, json.dumps(params))
|
||||
|
||||
# Delete token
|
||||
self.get_success(
|
||||
store.db_pool.simple_delete_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
)
|
||||
)
|
||||
|
||||
# Delete session (mimics expiry)
|
||||
self.get_success(
|
||||
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
|
||||
)
|
||||
|
||||
def test_advertised_flows(self):
|
||||
channel = self.make_request(b"POST", self.url, b"{}")
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
|
@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
|
||||
self.assertLessEqual(res, now_ms + self.validity_period)
|
||||
|
||||
|
||||
class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [register.register_servlets]
|
||||
url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
|
||||
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["registration_requires_token"] = True
|
||||
return config
|
||||
|
||||
def test_GET_token_valid(self):
|
||||
token = "abcd"
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": None,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
b"GET",
|
||||
f"{self.url}?token={token}",
|
||||
)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
self.assertEquals(channel.json_body["valid"], True)
|
||||
|
||||
def test_GET_token_invalid(self):
|
||||
token = "1234"
|
||||
channel = self.make_request(
|
||||
b"GET",
|
||||
f"{self.url}?token={token}",
|
||||
)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
self.assertEquals(channel.json_body["valid"], False)
|
||||
|
||||
@override_config(
|
||||
{"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
|
||||
)
|
||||
def test_GET_ratelimiting(self):
|
||||
token = "1234"
|
||||
|
||||
for i in range(0, 6):
|
||||
channel = self.make_request(
|
||||
b"GET",
|
||||
f"{self.url}?token={token}",
|
||||
)
|
||||
|
||||
if i == 5:
|
||||
self.assertEquals(channel.result["code"], b"429", channel.result)
|
||||
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
||||
else:
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
|
||||
|
||||
channel = self.make_request(
|
||||
b"GET",
|
||||
f"{self.url}?token={token}",
|
||||
)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue