Implement MSC3231: Token authenticated registration (#10142)

Signed-off-by: Callum Brown <callum@calcuode.com>

This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231).
This commit is contained in:
Callum Brown 2021-08-21 22:14:43 +01:00 committed by GitHub
parent ecd823d766
commit 947dbbdfd1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 2389 additions and 1 deletions

View file

@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync
from synapse.storage._base import db_to_json
from tests import unittest
from tests.unittest import override_config
@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config({"registration_requires_token": True})
def test_POST_registration_requires_token(self):
username = "kermit"
device_id = "frogfone"
token = "abcd"
store = self.hs.get_datastore()
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": None,
"pending": 0,
"completed": 0,
"expiry_time": None,
},
)
)
params = {
"username": username,
"password": "monkey",
"device_id": device_id,
}
# Request without auth to get flows and session
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
# Synapse adds a dummy stage to differentiate flows where otherwise one
# flow would be a subset of another flow.
self.assertCountEqual(
[[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
(f["stages"] for f in flows),
)
session = channel.json_body["session"]
# Do the registration token stage and check it has completed
params["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session,
}
request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"401", channel.result)
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
# Do the m.login.dummy stage and check registration was successful
params["auth"] = {
"type": LoginType.DUMMY,
"session": session,
}
request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
det_data = {
"user_id": f"@{username}:{self.hs.hostname}",
"home_server": self.hs.hostname,
"device_id": device_id,
}
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
# Check the `completed` counter has been incremented and pending is 0
res = self.get_success(
store.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
)
)
self.assertEquals(res["completed"], 1)
self.assertEquals(res["pending"], 0)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_invalid(self):
params = {
"username": "kermit",
"password": "monkey",
}
# Request without auth to get session
channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"]
# Test with token param missing (invalid)
params["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"session": session,
}
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
self.assertEquals(channel.json_body["completed"], [])
# Test with non-string (invalid)
params["auth"]["token"] = 1234
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
self.assertEquals(channel.json_body["completed"], [])
# Test with unknown token (invalid)
params["auth"]["token"] = "1234"
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEquals(channel.json_body["completed"], [])
@override_config({"registration_requires_token": True})
def test_POST_registration_token_limit_uses(self):
token = "abcd"
store = self.hs.get_datastore()
# Create token that can be used once
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": 1,
"pending": 0,
"completed": 0,
"expiry_time": None,
},
)
)
params1 = {"username": "bert", "password": "monkey"}
params2 = {"username": "ernie", "password": "monkey"}
# Do 2 requests without auth to get two session IDs
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
session1 = channel1.json_body["session"]
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
session2 = channel2.json_body["session"]
# Use token with session1 and check `pending` is 1
params1["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session1,
}
self.make_request(b"POST", self.url, json.dumps(params1))
# Repeat request to make sure pending isn't increased again
self.make_request(b"POST", self.url, json.dumps(params1))
pending = self.get_success(
store.db_pool.simple_select_one_onecol(
"registration_tokens",
keyvalues={"token": token},
retcol="pending",
)
)
self.assertEquals(pending, 1)
# Check auth fails when using token with session2
params2["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session2,
}
channel = self.make_request(b"POST", self.url, json.dumps(params2))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEquals(channel.json_body["completed"], [])
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
self.make_request(b"POST", self.url, json.dumps(params1))
# Check pending=0 and completed=1
res = self.get_success(
store.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
)
)
self.assertEquals(res["pending"], 0)
self.assertEquals(res["completed"], 1)
# Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, json.dumps(params2))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEquals(channel.json_body["completed"], [])
@override_config({"registration_requires_token": True})
def test_POST_registration_token_expiry(self):
token = "abcd"
now = self.hs.get_clock().time_msec()
store = self.hs.get_datastore()
# Create token that expired yesterday
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": None,
"pending": 0,
"completed": 0,
"expiry_time": now - 24 * 60 * 60 * 1000,
},
)
)
params = {"username": "kermit", "password": "monkey"}
# Request without auth to get session
channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"]
# Check authentication fails with expired token
params["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session,
}
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEquals(channel.json_body["completed"], [])
# Update token so it expires tomorrow
self.get_success(
store.db_pool.simple_update_one(
"registration_tokens",
keyvalues={"token": token},
updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
)
)
# Check authentication succeeds
channel = self.make_request(b"POST", self.url, json.dumps(params))
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_session_expiry(self):
"""Test `pending` is decremented when an uncompleted session expires."""
token = "abcd"
store = self.hs.get_datastore()
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": None,
"pending": 0,
"completed": 0,
"expiry_time": None,
},
)
)
# Do 2 requests without auth to get two session IDs
params1 = {"username": "bert", "password": "monkey"}
params2 = {"username": "ernie", "password": "monkey"}
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
session1 = channel1.json_body["session"]
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
session2 = channel2.json_body["session"]
# Use token with both sessions
params1["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session1,
}
self.make_request(b"POST", self.url, json.dumps(params1))
params2["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session2,
}
self.make_request(b"POST", self.url, json.dumps(params2))
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
self.make_request(b"POST", self.url, json.dumps(params1))
# Check `result` of registration token stage for session1 is `True`
result1 = self.get_success(
store.db_pool.simple_select_one_onecol(
"ui_auth_sessions_credentials",
keyvalues={
"session_id": session1,
"stage_type": LoginType.REGISTRATION_TOKEN,
},
retcol="result",
)
)
self.assertTrue(db_to_json(result1))
# Check `result` for session2 is the token used
result2 = self.get_success(
store.db_pool.simple_select_one_onecol(
"ui_auth_sessions_credentials",
keyvalues={
"session_id": session2,
"stage_type": LoginType.REGISTRATION_TOKEN,
},
retcol="result",
)
)
self.assertEquals(db_to_json(result2), token)
# Delete both sessions (mimics expiry)
self.get_success(
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
)
# Check pending is now 0
pending = self.get_success(
store.db_pool.simple_select_one_onecol(
"registration_tokens",
keyvalues={"token": token},
retcol="pending",
)
)
self.assertEquals(pending, 0)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_session_expiry_deleted_token(self):
"""Test session expiry doesn't break when the token is deleted.
1. Start but don't complete UIA with a registration token
2. Delete the token from the database
3. Expire the session
"""
token = "abcd"
store = self.hs.get_datastore()
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": None,
"pending": 0,
"completed": 0,
"expiry_time": None,
},
)
)
# Do request without auth to get a session ID
params = {"username": "kermit", "password": "monkey"}
channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"]
# Use token
params["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session,
}
self.make_request(b"POST", self.url, json.dumps(params))
# Delete token
self.get_success(
store.db_pool.simple_delete_one(
"registration_tokens",
keyvalues={"token": token},
)
)
# Delete session (mimics expiry)
self.get_success(
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
)
def test_advertised_flows(self):
channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)
class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets]
url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
def default_config(self):
config = super().default_config()
config["registration_requires_token"] = True
return config
def test_GET_token_valid(self):
token = "abcd"
store = self.hs.get_datastore()
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": None,
"pending": 0,
"completed": 0,
"expiry_time": None,
},
)
)
channel = self.make_request(
b"GET",
f"{self.url}?token={token}",
)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEquals(channel.json_body["valid"], True)
def test_GET_token_invalid(self):
token = "1234"
channel = self.make_request(
b"GET",
f"{self.url}?token={token}",
)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEquals(channel.json_body["valid"], False)
@override_config(
{"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
)
def test_GET_ratelimiting(self):
token = "1234"
for i in range(0, 6):
channel = self.make_request(
b"GET",
f"{self.url}?token={token}",
)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(
b"GET",
f"{self.url}?token={token}",
)
self.assertEquals(channel.result["code"], b"200", channel.result)