Implement MSC3231: Token authenticated registration (#10142)

Signed-off-by: Callum Brown <callum@calcuode.com>

This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231).
This commit is contained in:
Callum Brown 2021-08-21 22:14:43 +01:00 committed by GitHub
parent ecd823d766
commit 947dbbdfd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 2389 additions and 1 deletions

View File

@ -0,0 +1 @@
Add support for [MSC3231 - Token authenticated registration](https://github.com/matrix-org/matrix-doc/pull/3231). Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown.

View File

@ -53,6 +53,7 @@
- [Media](admin_api/media_admin_api.md) - [Media](admin_api/media_admin_api.md)
- [Purge History](admin_api/purge_history_api.md) - [Purge History](admin_api/purge_history_api.md)
- [Register Users](admin_api/register_api.md) - [Register Users](admin_api/register_api.md)
- [Registration Tokens](usage/administration/admin_api/registration_tokens.md)
- [Manipulate Room Membership](admin_api/room_membership.md) - [Manipulate Room Membership](admin_api/room_membership.md)
- [Rooms](admin_api/rooms.md) - [Rooms](admin_api/rooms.md)
- [Server Notices](admin_api/server_notices.md) - [Server Notices](admin_api/server_notices.md)

View File

@ -793,6 +793,8 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# is using # is using
# - one for registration that ratelimits registration requests based on the # - one for registration that ratelimits registration requests based on the
# client's IP address. # client's IP address.
# - one for checking the validity of registration tokens that ratelimits
# requests based on the client's IP address.
# - one for login that ratelimits login requests based on the client's IP # - one for login that ratelimits login requests based on the client's IP
# address. # address.
# - one for login that ratelimits login requests based on the account the # - one for login that ratelimits login requests based on the account the
@ -821,6 +823,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# per_second: 0.17 # per_second: 0.17
# burst_count: 3 # burst_count: 3
# #
#rc_registration_token_validity:
# per_second: 0.1
# burst_count: 5
#
#rc_login: #rc_login:
# address: # address:
# per_second: 0.17 # per_second: 0.17
@ -1169,6 +1175,15 @@ url_preview_accept_language:
# #
#enable_3pid_lookup: true #enable_3pid_lookup: true
# Require users to submit a token during registration.
# Tokens can be managed using the admin API:
# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
# Note that `enable_registration` must be set to `true`.
# Disabling this option will not delete any tokens previously generated.
# Defaults to false. Uncomment the following to require tokens:
#
#registration_requires_token: true
# If set, allows registration of standard or admin accounts by anyone who # If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled. # has the shared secret, even if registration is otherwise disabled.
# #

View File

@ -0,0 +1,295 @@
# Registration Tokens
This API allows you to manage tokens which can be used to authenticate
registration requests, as proposed in [MSC3231](https://github.com/govynnus/matrix-doc/blob/token-registration/proposals/3231-token-authenticated-registration.md).
To use it, you will need to enable the `registration_requires_token` config
option, and authenticate by providing an `access_token` for a server admin:
see [Admin API](../../usage/administration/admin_api).
Note that this API is still experimental; not all clients may support it yet.
## Registration token objects
Most endpoints make use of JSON objects that contain details about tokens.
These objects have the following fields:
- `token`: The token which can be used to authenticate registration.
- `uses_allowed`: The number of times the token can be used to complete a
registration before it becomes invalid.
- `pending`: The number of pending uses the token has. When someone uses
the token to authenticate themselves, the pending counter is incremented
so that the token is not used more than the permitted number of times.
When the person completes registration the pending counter is decremented,
and the completed counter is incremented.
- `completed`: The number of times the token has been used to successfully
complete a registration.
- `expiry_time`: The latest time the token is valid. Given as the number of
milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
To convert this into a human-readable form you can remove the milliseconds
and use the `date` command. For example, `date -d '@1625394937'`.
## List all tokens
Lists all tokens and details about them. If the request is successful, the top
level JSON object will have a `registration_tokens` key which is an array of
registration token objects.
```
GET /_synapse/admin/v1/registration_tokens
```
Optional query parameters:
- `valid`: `true` or `false`. If `true`, only valid tokens are returned.
If `false`, only tokens that have expired or have had all uses exhausted are
returned. If omitted, all tokens are returned regardless of validity.
Example:
```
GET /_synapse/admin/v1/registration_tokens
```
```
200 OK
{
"registration_tokens": [
{
"token": "abcd",
"uses_allowed": 3,
"pending": 0,
"completed": 1,
"expiry_time": null
},
{
"token": "pqrs",
"uses_allowed": 2,
"pending": 1,
"completed": 1,
"expiry_time": null
},
{
"token": "wxyz",
"uses_allowed": null,
"pending": 0,
"completed": 9,
"expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC
}
]
}
```
Example using the `valid` query parameter:
```
GET /_synapse/admin/v1/registration_tokens?valid=false
```
```
200 OK
{
"registration_tokens": [
{
"token": "pqrs",
"uses_allowed": 2,
"pending": 1,
"completed": 1,
"expiry_time": null
},
{
"token": "wxyz",
"uses_allowed": null,
"pending": 0,
"completed": 9,
"expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC
}
]
}
```
## Get one token
Get details about a single token. If the request is successful, the response
body will be a registration token object.
```
GET /_synapse/admin/v1/registration_tokens/<token>
```
Path parameters:
- `token`: The registration token to return details of.
Example:
```
GET /_synapse/admin/v1/registration_tokens/abcd
```
```
200 OK
{
"token": "abcd",
"uses_allowed": 3,
"pending": 0,
"completed": 1,
"expiry_time": null
}
```
## Create token
Create a new registration token. If the request is successful, the newly created
token will be returned as a registration token object in the response body.
```
POST /_synapse/admin/v1/registration_tokens/new
```
The request body must be a JSON object and can contain the following fields:
- `token`: The registration token. A string of no more than 64 characters that
consists only of characters matched by the regex `[A-Za-z0-9-_]`.
Default: randomly generated.
- `uses_allowed`: The integer number of times the token can be used to complete
a registration before it becomes invalid.
Default: `null` (unlimited uses).
- `expiry_time`: The latest time the token is valid. Given as the number of
milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
You could use, for example, `date '+%s000' -d 'tomorrow'`.
Default: `null` (token does not expire).
- `length`: The length of the token randomly generated if `token` is not
specified. Must be between 1 and 64 inclusive. Default: `16`.
If a field is omitted the default is used.
Example using defaults:
```
POST /_synapse/admin/v1/registration_tokens/new
{}
```
```
200 OK
{
"token": "0M-9jbkf2t_Tgiw1",
"uses_allowed": null,
"pending": 0,
"completed": 0,
"expiry_time": null
}
```
Example specifying some fields:
```
POST /_synapse/admin/v1/registration_tokens/new
{
"token": "defg",
"uses_allowed": 1
}
```
```
200 OK
{
"token": "defg",
"uses_allowed": 1,
"pending": 0,
"completed": 0,
"expiry_time": null
}
```
## Update token
Update the number of allowed uses or expiry time of a token. If the request is
successful, the updated token will be returned as a registration token object
in the response body.
```
PUT /_synapse/admin/v1/registration_tokens/<token>
```
Path parameters:
- `token`: The registration token to update.
The request body must be a JSON object and can contain the following fields:
- `uses_allowed`: The integer number of times the token can be used to complete
a registration before it becomes invalid. By setting `uses_allowed` to `0`
the token can be easily made invalid without deleting it.
If `null` the token will have an unlimited number of uses.
- `expiry_time`: The latest time the token is valid. Given as the number of
milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
If `null` the token will not expire.
If a field is omitted its value is not modified.
Example:
```
PUT /_synapse/admin/v1/registration_tokens/defg
{
"expiry_time": 4781243146000 // 2121-07-06 11:05:46 UTC
}
```
```
200 OK
{
"token": "defg",
"uses_allowed": 1,
"pending": 0,
"completed": 0,
"expiry_time": 4781243146000
}
```
## Delete token
Delete a registration token. If the request is successful, the response body
will be an empty JSON object.
```
DELETE /_synapse/admin/v1/registration_tokens/<token>
```
Path parameters:
- `token`: The registration token to delete.
Example:
```
DELETE /_synapse/admin/v1/registration_tokens/wxyz
```
```
200 OK
{}
```
## Errors
If a request fails a "standard error response" will be returned as defined in
the [Matrix Client-Server API specification](https://matrix.org/docs/spec/client_server/r0.6.1#api-standards).
For example, if the token specified in a path parameter does not exist a
`404 Not Found` error will be returned.
```
GET /_synapse/admin/v1/registration_tokens/1234
```
```
404 Not Found
{
"errcode": "M_NOT_FOUND",
"error": "No such registration token: 1234"
}
```

View File

@ -236,6 +236,7 @@ expressions:
# Registration/login requests # Registration/login requests
^/_matrix/client/(api/v1|r0|unstable)/login$ ^/_matrix/client/(api/v1|r0|unstable)/login$
^/_matrix/client/(r0|unstable)/register$ ^/_matrix/client/(r0|unstable)/register$
^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$
# Event sending requests # Event sending requests
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact

View File

@ -79,6 +79,7 @@ class LoginType:
TERMS = "m.login.terms" TERMS = "m.login.terms"
SSO = "m.login.sso" SSO = "m.login.sso"
DUMMY = "m.login.dummy" DUMMY = "m.login.dummy"
REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
# This is used in the `type` parameter for /register when called by # This is used in the `type` parameter for /register when called by

View File

@ -95,7 +95,10 @@ from synapse.rest.client.profile import (
ProfileRestServlet, ProfileRestServlet,
) )
from synapse.rest.client.push_rule import PushRuleRestServlet from synapse.rest.client.push_rule import PushRuleRestServlet
from synapse.rest.client.register import RegisterRestServlet from synapse.rest.client.register import (
RegisterRestServlet,
RegistrationTokenValidityRestServlet,
)
from synapse.rest.client.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.client.voip import VoipRestServlet from synapse.rest.client.voip import VoipRestServlet
@ -279,6 +282,7 @@ class GenericWorkerServer(HomeServer):
resource = JsonResource(self, canonical_json=False) resource = JsonResource(self, canonical_json=False)
RegisterRestServlet(self).register(resource) RegisterRestServlet(self).register(resource)
RegistrationTokenValidityRestServlet(self).register(resource)
login.register_servlets(self, resource) login.register_servlets(self, resource)
ThreepidRestServlet(self).register(resource) ThreepidRestServlet(self).register(resource)
DevicesRestServlet(self).register(resource) DevicesRestServlet(self).register(resource)

View File

@ -79,6 +79,11 @@ class RatelimitConfig(Config):
self.rc_registration = RateLimitConfig(config.get("rc_registration", {})) self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
self.rc_registration_token_validity = RateLimitConfig(
config.get("rc_registration_token_validity", {}),
defaults={"per_second": 0.1, "burst_count": 5},
)
rc_login_config = config.get("rc_login", {}) rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {})) self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {})) self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
@ -143,6 +148,8 @@ class RatelimitConfig(Config):
# is using # is using
# - one for registration that ratelimits registration requests based on the # - one for registration that ratelimits registration requests based on the
# client's IP address. # client's IP address.
# - one for checking the validity of registration tokens that ratelimits
# requests based on the client's IP address.
# - one for login that ratelimits login requests based on the client's IP # - one for login that ratelimits login requests based on the client's IP
# address. # address.
# - one for login that ratelimits login requests based on the account the # - one for login that ratelimits login requests based on the account the
@ -171,6 +178,10 @@ class RatelimitConfig(Config):
# per_second: 0.17 # per_second: 0.17
# burst_count: 3 # burst_count: 3
# #
#rc_registration_token_validity:
# per_second: 0.1
# burst_count: 5
#
#rc_login: #rc_login:
# address: # address:
# per_second: 0.17 # per_second: 0.17

View File

@ -33,6 +33,9 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_requires_token = config.get(
"registration_requires_token", False
)
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@ -140,6 +143,9 @@ class RegistrationConfig(Config):
"mechanism by removing the `access_token_lifetime` option." "mechanism by removing the `access_token_lifetime` option."
) )
# The fallback template used for authenticating using a registration token
self.registration_token_template = self.read_template("registration_token.html")
# The success template used during fallback auth. # The success template used during fallback auth.
self.fallback_success_template = self.read_template("auth_success.html") self.fallback_success_template = self.read_template("auth_success.html")
@ -199,6 +205,15 @@ class RegistrationConfig(Config):
# #
#enable_3pid_lookup: true #enable_3pid_lookup: true
# Require users to submit a token during registration.
# Tokens can be managed using the admin API:
# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
# Note that `enable_registration` must be set to `true`.
# Disabling this option will not delete any tokens previously generated.
# Defaults to false. Uncomment the following to require tokens:
#
#registration_requires_token: true
# If set, allows registration of standard or admin accounts by anyone who # If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled. # has the shared secret, even if registration is otherwise disabled.
# #

View File

@ -34,3 +34,8 @@ class UIAuthSessionDataConstants:
# used by validate_user_via_ui_auth to store the mxid of the user we are validating # used by validate_user_via_ui_auth to store the mxid of the user we are validating
# for. # for.
REQUEST_USER_ID = "request_user_id" REQUEST_USER_ID = "request_user_id"
# used during registration to store the registration token used (if required) so that:
# - we can prevent a token being used twice by one session
# - we can 'use up' the token after registration has successfully completed
REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"

View File

@ -241,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return await self._check_threepid("msisdn", authdict) return await self._check_threepid("msisdn", authdict)
class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.REGISTRATION_TOKEN
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self._enabled = bool(hs.config.registration_requires_token)
self.store = hs.get_datastore()
def is_enabled(self) -> bool:
return self._enabled
async def check_auth(self, authdict: dict, clientip: str) -> Any:
if "token" not in authdict:
raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
if not isinstance(authdict["token"], str):
raise LoginError(
400, "Registration token must be a string", Codes.INVALID_PARAM
)
if "session" not in authdict:
raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
# Get these here to avoid cyclic dependencies
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
auth_handler = self.hs.get_auth_handler()
session = authdict["session"]
token = authdict["token"]
# If the LoginType.REGISTRATION_TOKEN stage has already been completed,
# return early to avoid incrementing `pending` again.
stored_token = await auth_handler.get_session_data(
session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
)
if stored_token:
if token != stored_token:
raise LoginError(
400, "Registration token has changed", Codes.INVALID_PARAM
)
else:
return token
if await self.store.registration_token_is_valid(token):
# Increment pending counter, so that if token has limited uses it
# can't be used up by someone else in the meantime.
await self.store.set_registration_token_pending(token)
# Store the token in the UIA session, so that once registration
# is complete `completed` can be incremented.
await auth_handler.set_session_data(
session,
UIAuthSessionDataConstants.REGISTRATION_TOKEN,
token,
)
# The token will be stored as the result of the authentication stage
# in ui_auth_sessions_credentials. This allows the pending counter
# for tokens to be decremented when expired sessions are deleted.
return token
else:
raise LoginError(
401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
)
INTERACTIVE_AUTH_CHECKERS = [ INTERACTIVE_AUTH_CHECKERS = [
DummyAuthChecker, DummyAuthChecker,
TermsAuthChecker, TermsAuthChecker,
RecaptchaAuthChecker, RecaptchaAuthChecker,
EmailIdentityAuthChecker, EmailIdentityAuthChecker,
MsisdnAuthChecker, MsisdnAuthChecker,
RegistrationTokenAuthChecker,
] ]
"""A list of UserInteractiveAuthChecker classes""" """A list of UserInteractiveAuthChecker classes"""

View File

@ -0,0 +1,23 @@
<html>
<head>
<title>Authentication</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
</head>
<body>
<form id="registrationForm" method="post" action="{{ myurl }}">
<div>
{% if error is defined %}
<p class="error"><strong>Error: {{ error }}</strong></p>
{% endif %}
<p>
Please enter a registration token.
</p>
<input type="hidden" name="session" value="{{ session }}" />
<input type="text" name="token" />
<input type="submit" value="Authenticate" />
</div>
</form>
</body>
</html>

View File

@ -36,6 +36,11 @@ from synapse.rest.admin.event_reports import (
) )
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.registration_tokens import (
ListRegistrationTokensRestServlet,
NewRegistrationTokenRestServlet,
RegistrationTokenRestServlet,
)
from synapse.rest.admin.rooms import ( from synapse.rest.admin.rooms import (
DeleteRoomRestServlet, DeleteRoomRestServlet,
ForwardExtremitiesRestServlet, ForwardExtremitiesRestServlet,
@ -238,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomEventContextServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server)
RateLimitRestServlet(hs).register(http_server) RateLimitRestServlet(hs).register(http_server)
UsernameAvailableRestServlet(hs).register(http_server) UsernameAvailableRestServlet(hs).register(http_server)
ListRegistrationTokensRestServlet(hs).register(http_server)
NewRegistrationTokenRestServlet(hs).register(http_server)
RegistrationTokenRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource( def register_servlets_for_client_rest_resource(

View File

@ -0,0 +1,321 @@
# Copyright 2021 Callum Brown
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import string
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
parse_boolean,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ListRegistrationTokensRestServlet(RestServlet):
"""List registration tokens.
To list all tokens:
GET /_synapse/admin/v1/registration_tokens
200 OK
{
"registration_tokens": [
{
"token": "abcd",
"uses_allowed": 3,
"pending": 0,
"completed": 1,
"expiry_time": null
},
{
"token": "wxyz",
"uses_allowed": null,
"pending": 0,
"completed": 9,
"expiry_time": 1625394937000
}
]
}
The optional query parameter `valid` can be used to filter the response.
If it is `true`, only valid tokens are returned. If it is `false`, only
tokens that have expired or have had all uses exhausted are returned.
If it is omitted, all tokens are returned regardless of validity.
"""
PATTERNS = admin_patterns("/registration_tokens$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
valid = parse_boolean(request, "valid")
token_list = await self.store.get_registration_tokens(valid)
return 200, {"registration_tokens": token_list}
class NewRegistrationTokenRestServlet(RestServlet):
"""Create a new registration token.
For example, to create a token specifying some fields:
POST /_synapse/admin/v1/registration_tokens/new
{
"token": "defg",
"uses_allowed": 1
}
200 OK
{
"token": "defg",
"uses_allowed": 1,
"pending": 0,
"completed": 0,
"expiry_time": null
}
Defaults are used for any fields not specified.
"""
PATTERNS = admin_patterns("/registration_tokens/new$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
# A string of all the characters allowed to be in a registration_token
self.allowed_chars = string.ascii_letters + string.digits + "-_"
self.allowed_chars_set = set(self.allowed_chars)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
if "token" in body:
token = body["token"]
if not isinstance(token, str):
raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
if not (0 < len(token) <= 64):
raise SynapseError(
400,
"token must not be empty and must not be longer than 64 characters",
Codes.INVALID_PARAM,
)
if not set(token).issubset(self.allowed_chars_set):
raise SynapseError(
400,
"token must consist only of characters matched by the regex [A-Za-z0-9-_]",
Codes.INVALID_PARAM,
)
else:
# Get length of token to generate (default is 16)
length = body.get("length", 16)
if not isinstance(length, int):
raise SynapseError(
400, "length must be an integer", Codes.INVALID_PARAM
)
if not (0 < length <= 64):
raise SynapseError(
400,
"length must be greater than zero and not greater than 64",
Codes.INVALID_PARAM,
)
# Generate token
token = await self.store.generate_registration_token(
length, self.allowed_chars
)
uses_allowed = body.get("uses_allowed", None)
if not (
uses_allowed is None
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
):
raise SynapseError(
400,
"uses_allowed must be a non-negative integer or null",
Codes.INVALID_PARAM,
)
expiry_time = body.get("expiry_time", None)
if not isinstance(expiry_time, (int, type(None))):
raise SynapseError(
400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
)
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
raise SynapseError(
400, "expiry_time must not be in the past", Codes.INVALID_PARAM
)
created = await self.store.create_registration_token(
token, uses_allowed, expiry_time
)
if not created:
raise SynapseError(
400, f"Token already exists: {token}", Codes.INVALID_PARAM
)
resp = {
"token": token,
"uses_allowed": uses_allowed,
"pending": 0,
"completed": 0,
"expiry_time": expiry_time,
}
return 200, resp
class RegistrationTokenRestServlet(RestServlet):
"""Retrieve, update, or delete the given token.
For example,
to retrieve a token:
GET /_synapse/admin/v1/registration_tokens/abcd
200 OK
{
"token": "abcd",
"uses_allowed": 3,
"pending": 0,
"completed": 1,
"expiry_time": null
}
to update a token:
PUT /_synapse/admin/v1/registration_tokens/defg
{
"uses_allowed": 5,
"expiry_time": 4781243146000
}
200 OK
{
"token": "defg",
"uses_allowed": 5,
"pending": 0,
"completed": 0,
"expiry_time": 4781243146000
}
to delete a token:
DELETE /_synapse/admin/v1/registration_tokens/wxyz
200 OK
{}
"""
PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
"""Retrieve a registration token."""
await assert_requester_is_admin(self.auth, request)
token_info = await self.store.get_one_registration_token(token)
# If no result return a 404
if token_info is None:
raise NotFoundError(f"No such registration token: {token}")
return 200, token_info
async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
"""Update a registration token."""
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
new_attributes = {}
# Only add uses_allowed to new_attributes if it is present and valid
if "uses_allowed" in body:
uses_allowed = body["uses_allowed"]
if not (
uses_allowed is None
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
):
raise SynapseError(
400,
"uses_allowed must be a non-negative integer or null",
Codes.INVALID_PARAM,
)
new_attributes["uses_allowed"] = uses_allowed
if "expiry_time" in body:
expiry_time = body["expiry_time"]
if not isinstance(expiry_time, (int, type(None))):
raise SynapseError(
400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
)
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
raise SynapseError(
400, "expiry_time must not be in the past", Codes.INVALID_PARAM
)
new_attributes["expiry_time"] = expiry_time
if len(new_attributes) == 0:
# Nothing to update, get token info to return
token_info = await self.store.get_one_registration_token(token)
else:
token_info = await self.store.update_registration_token(
token, new_attributes
)
# If no result return a 404
if token_info is None:
raise NotFoundError(f"No such registration token: {token}")
return 200, token_info
async def on_DELETE(
self, request: SynapseRequest, token: str
) -> Tuple[int, JsonDict]:
"""Delete a registration token."""
await assert_requester_is_admin(self.auth, request)
if await self.store.delete_registration_token(token):
return 200, {}
raise NotFoundError(f"No such registration token: {token}")

View File

@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
self.recaptcha_template = hs.config.recaptcha_template self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template self.terms_template = hs.config.terms_template
self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template self.success_template = hs.config.fallback_success_template
async def on_GET(self, request, stagetype): async def on_GET(self, request, stagetype):
@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet):
# re-authenticate with their SSO provider. # re-authenticate with their SSO provider.
html = await self.auth_handler.start_sso_ui_auth(request, session) html = await self.auth_handler.start_sso_ui_auth(request, session)
elif stagetype == LoginType.REGISTRATION_TOKEN:
html = self.registration_token_template.render(
session=session,
myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
)
else: else:
raise SynapseError(404, "Unknown auth stage type") raise SynapseError(404, "Unknown auth stage type")
@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet):
# The SSO fallback workflow should not post here, # The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.") raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
elif stagetype == LoginType.REGISTRATION_TOKEN:
token = parse_string(request, "token", required=True)
authdict = {"session": session, "token": token}
try:
await self.auth_handler.add_oob_auth(
LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
)
except LoginError as e:
html = self.registration_token_template.render(
session=session,
myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
error=e.msg,
)
else:
html = self.success_template.render()
else: else:
raise SynapseError(404, "Unknown auth stage type") raise SynapseError(404, "Unknown auth stage type")

View File

@ -28,6 +28,7 @@ from synapse.api.errors import (
ThreepidValidationError, ThreepidValidationError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.captcha import CaptchaConfig from synapse.config.captcha import CaptchaConfig
from synapse.config.consent import ConsentConfig from synapse.config.consent import ConsentConfig
@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
return 200, {"available": True} return 200, {"available": True}
class RegistrationTokenValidityRestServlet(RestServlet):
"""Check the validity of a registration token.
Example:
GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
200 OK
{
"valid": true
}
"""
PATTERNS = client_patterns(
f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
releases=(),
unstable=True,
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__()
self.hs = hs
self.store = hs.get_datastore()
self.ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
)
async def on_GET(self, request):
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
if not self.hs.config.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
token = parse_string(request, "token", required=True)
valid = await self.store.registration_token_is_valid(token)
return 200, {"valid": valid}
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$") PATTERNS = client_patterns("/register$")
@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet):
) )
if registered: if registered:
# Check if a token was used to authenticate registration
registration_token = await self.auth_handler.get_session_data(
session_id,
UIAuthSessionDataConstants.REGISTRATION_TOKEN,
)
if registration_token:
# Increment the `completed` counter for the token
await self.store.use_registration_token(registration_token)
# Indicate that the token has been successfully used so that
# pending is not decremented again when expiring old UIA sessions.
await self.store.mark_ui_auth_stage_complete(
session_id,
LoginType.REGISTRATION_TOKEN,
True,
)
await self.registration_handler.post_registration_actions( await self.registration_handler.post_registration_actions(
user_id=registered_user_id, user_id=registered_user_id,
auth_result=auth_result, auth_result=auth_result,
@ -868,6 +934,11 @@ def _calculate_registration_flows(
for flow in flows: for flow in flows:
flow.insert(0, LoginType.RECAPTCHA) flow.insert(0, LoginType.RECAPTCHA)
# Prepend registration token to all flows if we're requiring a token
if config.registration_requires_token:
for flow in flows:
flow.insert(0, LoginType.REGISTRATION_TOKEN)
return flows return flows
@ -876,4 +947,5 @@ def register_servlets(hs, http_server):
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server)
RegistrationSubmitTokenServlet(hs).register(http_server) RegistrationSubmitTokenServlet(hs).register(http_server)
RegistrationTokenValidityRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View File

@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="update_access_token_last_validated", desc="update_access_token_last_validated",
) )
async def registration_token_is_valid(self, token: str) -> bool:
"""Checks if a token can be used to authenticate a registration.
Args:
token: The registration token to be checked
Returns:
True if the token is valid, False otherwise.
"""
res = await self.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True,
)
# Check if the token exists
if res is None:
return False
# Check if the token has expired
now = self._clock.time_msec()
if res["expiry_time"] and res["expiry_time"] < now:
return False
# Check if the token has been used up
if (
res["uses_allowed"]
and res["pending"] + res["completed"] >= res["uses_allowed"]
):
return False
# Otherwise, the token is valid
return True
async def set_registration_token_pending(self, token: str) -> None:
"""Increment the pending registrations counter for a token.
Args:
token: The registration token pending use
"""
def _set_registration_token_pending_txn(txn):
pending = self.db_pool.simple_select_one_onecol_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcol="pending",
)
self.db_pool.simple_update_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
updatevalues={"pending": pending + 1},
)
return await self.db_pool.runInteraction(
"set_registration_token_pending", _set_registration_token_pending_txn
)
async def use_registration_token(self, token: str) -> None:
"""Complete a use of the given registration token.
The `pending` counter will be decremented, and the `completed`
counter will be incremented.
Args:
token: The registration token to be 'used'
"""
def _use_registration_token_txn(txn):
# Normally, res is Optional[Dict[str, Any]].
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
) # type: ignore
# Decrement pending and increment completed
self.db_pool.simple_update_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
updatevalues={
"completed": res["completed"] + 1,
"pending": res["pending"] - 1,
},
)
return await self.db_pool.runInteraction(
"use_registration_token", _use_registration_token_txn
)
async def get_registration_tokens(
self, valid: Optional[bool] = None
) -> List[Dict[str, Any]]:
"""List all registration tokens. Used by the admin API.
Args:
valid: If True, only valid tokens are returned.
If False, only invalid tokens are returned.
Default is None: return all tokens regardless of validity.
Returns:
A list of dicts, each containing details of a token.
"""
def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
if valid is None:
# Return all tokens regardless of validity
txn.execute("SELECT * FROM registration_tokens")
elif valid:
# Select valid tokens only
sql = (
"SELECT * FROM registration_tokens WHERE "
"(uses_allowed > pending + completed OR uses_allowed IS NULL) "
"AND (expiry_time > ? OR expiry_time IS NULL)"
)
txn.execute(sql, [now])
else:
# Select invalid tokens only
sql = (
"SELECT * FROM registration_tokens WHERE "
"uses_allowed <= pending + completed OR expiry_time <= ?"
)
txn.execute(sql, [now])
return self.db_pool.cursor_to_dict(txn)
return await self.db_pool.runInteraction(
"select_registration_tokens",
select_registration_tokens_txn,
self._clock.time_msec(),
valid,
)
async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Get info about the given registration token. Used by the admin API.
Args:
token: The token to retrieve information about.
Returns:
A dict, or None if token doesn't exist.
"""
return await self.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True,
desc="get_one_registration_token",
)
async def generate_registration_token(
self, length: int, chars: str
) -> Optional[str]:
"""Generate a random registration token. Used by the admin API.
Args:
length: The length of the token to generate.
chars: A string of the characters allowed in the generated token.
Returns:
The generated token.
Raises:
SynapseError if a unique registration token could still not be
generated after a few tries.
"""
# Make a few attempts at generating a unique token of the required
# length before failing.
for _i in range(3):
# Generate token
token = "".join(random.choices(chars, k=length))
# Check if the token already exists
existing_token = await self.db_pool.simple_select_one_onecol(
"registration_tokens",
keyvalues={"token": token},
retcol="token",
allow_none=True,
desc="check_if_registration_token_exists",
)
if existing_token is None:
# The generated token doesn't exist yet, return it
return token
raise SynapseError(
500,
"Unable to generate a unique registration token. Try again with a greater length",
Codes.UNKNOWN,
)
async def create_registration_token(
self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
) -> bool:
"""Create a new registration token. Used by the admin API.
Args:
token: The token to create.
uses_allowed: The number of times the token can be used to complete
a registration before it becomes invalid. A value of None indicates
unlimited uses.
expiry_time: The latest time the token is valid. Given as the
number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
None indicates that the token does not expire.
Returns:
Whether the row was inserted or not.
"""
def _create_registration_token_txn(txn):
row = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcols=["token"],
allow_none=True,
)
if row is not None:
# Token already exists
return False
self.db_pool.simple_insert_txn(
txn,
"registration_tokens",
values={
"token": token,
"uses_allowed": uses_allowed,
"pending": 0,
"completed": 0,
"expiry_time": expiry_time,
},
)
return True
return await self.db_pool.runInteraction(
"create_registration_token", _create_registration_token_txn
)
async def update_registration_token(
self, token: str, updatevalues: Dict[str, Optional[int]]
) -> Optional[Dict[str, Any]]:
"""Update a registration token. Used by the admin API.
Args:
token: The token to update.
updatevalues: A dict with the fields to update. E.g.:
`{"uses_allowed": 3}` to update just uses_allowed, or
`{"uses_allowed": 3, "expiry_time": None}` to update both.
This is passed straight to simple_update_one.
Returns:
A dict with all info about the token, or None if token doesn't exist.
"""
def _update_registration_token_txn(txn):
try:
self.db_pool.simple_update_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
updatevalues=updatevalues,
)
except StoreError:
# Update failed because token does not exist
return None
# Get all info about the token so it can be sent in the response
return self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcols=[
"token",
"uses_allowed",
"pending",
"completed",
"expiry_time",
],
allow_none=True,
)
return await self.db_pool.runInteraction(
"update_registration_token", _update_registration_token_txn
)
async def delete_registration_token(self, token: str) -> bool:
"""Delete a registration token. Used by the admin API.
Args:
token: The token to delete.
Returns:
Whether the token was successfully deleted or not.
"""
try:
await self.db_pool.simple_delete_one(
"registration_tokens",
keyvalues={"token": token},
desc="delete_registration_token",
)
except StoreError:
# Deletion failed because token does not exist
return False
return True
@cached() @cached()
async def mark_access_token_as_used(self, token_id: int) -> None: async def mark_access_token_as_used(self, token_id: int) -> None:
""" """

View File

@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import attr import attr
from synapse.api.constants import LoginType
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction from synapse.storage.database import LoggingTransaction
@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
keyvalues={}, keyvalues={},
) )
# If a registration token was used, decrement the pending counter
# before deleting the session.
rows = self.db_pool.simple_select_many_txn(
txn,
table="ui_auth_sessions_credentials",
column="session_id",
iterable=session_ids,
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
retcols=["result"],
)
# Get the tokens used and how much pending needs to be decremented by.
token_counts: Dict[str, int] = {}
for r in rows:
# If registration was successfully completed, the result of the
# registration token stage for that session will be True.
# If a token was used to authenticate, but registration was
# never completed, the result will be the token used.
token = db_to_json(r["result"])
if isinstance(token, str):
token_counts[token] = token_counts.get(token, 0) + 1
# Update the `pending` counters.
if len(token_counts) > 0:
token_rows = self.db_pool.simple_select_many_txn(
txn,
table="registration_tokens",
column="token",
iterable=list(token_counts.keys()),
keyvalues={},
retcols=["token", "pending"],
)
for token_row in token_rows:
token = token_row["token"]
new_pending = token_row["pending"] - token_counts[token]
self.db_pool.simple_update_one_txn(
txn,
table="registration_tokens",
keyvalues={"token": token},
updatevalues={"pending": new_pending},
)
# Delete the corresponding completed credentials. # Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn( self.db_pool.simple_delete_many_txn(
txn, txn,

View File

@ -0,0 +1,23 @@
/* Copyright 2021 Callum Brown
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS registration_tokens(
token TEXT NOT NULL, -- The token that can be used for authentication.
uses_allowed INT, -- The total number of times this token can be used. NULL if no limit.
pending INT NOT NULL, -- The number of in progress registrations using this token.
completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
UNIQUE (token)
);

View File

@ -0,0 +1,710 @@
# Copyright 2021 Callum Brown
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import string
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login
from tests import unittest
class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.other_user_tok = self.login("user", "pass")
self.url = "/_synapse/admin/v1/registration_tokens"
def _new_token(self, **kwargs):
"""Helper function to create a token."""
token = kwargs.get(
"token",
"".join(random.choices(string.ascii_letters, k=8)),
)
self.get_success(
self.store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": kwargs.get("uses_allowed", None),
"pending": kwargs.get("pending", 0),
"completed": kwargs.get("completed", 0),
"expiry_time": kwargs.get("expiry_time", None),
},
)
)
return token
# CREATION
def test_create_no_auth(self):
"""Try to create a token without authentication."""
channel = self.make_request("POST", self.url + "/new", {})
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_create_requester_not_admin(self):
"""Try to create a token while not an admin."""
channel = self.make_request(
"POST",
self.url + "/new",
{},
access_token=self.other_user_tok,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_create_using_defaults(self):
"""Create a token using all the defaults."""
channel = self.make_request(
"POST",
self.url + "/new",
{},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
self.assertEqual(channel.json_body["pending"], 0)
self.assertEqual(channel.json_body["completed"], 0)
def test_create_specifying_fields(self):
"""Create a token specifying the value of all fields."""
data = {
"token": "abcd",
"uses_allowed": 1,
"expiry_time": self.clock.time_msec() + 1000000,
}
channel = self.make_request(
"POST",
self.url + "/new",
data,
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["token"], "abcd")
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
self.assertEqual(channel.json_body["pending"], 0)
self.assertEqual(channel.json_body["completed"], 0)
def test_create_with_null_value(self):
"""Create a token specifying unlimited uses and no expiry."""
data = {
"uses_allowed": None,
"expiry_time": None,
}
channel = self.make_request(
"POST",
self.url + "/new",
data,
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
self.assertEqual(channel.json_body["pending"], 0)
self.assertEqual(channel.json_body["completed"], 0)
def test_create_token_too_long(self):
"""Check token longer than 64 chars is invalid."""
data = {"token": "a" * 65}
channel = self.make_request(
"POST",
self.url + "/new",
data,
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_invalid_chars(self):
"""Check you can't create token with invalid characters."""
data = {
"token": "abc/def",
}
channel = self.make_request(
"POST",
self.url + "/new",
data,
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_already_exists(self):
"""Check you can't create token that already exists."""
data = {
"token": "abcd",
}
channel1 = self.make_request(
"POST",
self.url + "/new",
data,
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
channel2 = self.make_request(
"POST",
self.url + "/new",
data,
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_unable_to_generate_token(self):
"""Check right error is raised when server can't generate unique token."""
# Create all possible single character tokens
tokens = []
for c in string.ascii_letters + string.digits + "-_":
tokens.append(
{
"token": c,
"uses_allowed": None,
"pending": 0,
"completed": 0,
"expiry_time": None,
}
)
self.get_success(
self.store.db_pool.simple_insert_many(
"registration_tokens",
tokens,
"create_all_registration_tokens",
)
)
# Check creating a single character token fails with a 500 status code
channel = self.make_request(
"POST",
self.url + "/new",
{"length": 1},
access_token=self.admin_user_tok,
)
self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
def test_create_uses_allowed(self):
"""Check you can only create a token with good values for uses_allowed."""
# Should work with 0 (token is invalid from the start)
channel = self.make_request(
"POST",
self.url + "/new",
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["uses_allowed"], 0)
# Should fail with negative integer
channel = self.make_request(
"POST",
self.url + "/new",
{"uses_allowed": -5},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with float
channel = self.make_request(
"POST",
self.url + "/new",
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_expiry_time(self):
"""Check you can't create a token with an invalid expiry_time."""
# Should fail with a time in the past
channel = self.make_request(
"POST",
self.url + "/new",
{"expiry_time": self.clock.time_msec() - 10000},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with float
channel = self.make_request(
"POST",
self.url + "/new",
{"expiry_time": self.clock.time_msec() + 1000000.5},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_length(self):
"""Check you can only generate a token with a valid length."""
# Should work with 64
channel = self.make_request(
"POST",
self.url + "/new",
{"length": 64},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(len(channel.json_body["token"]), 64)
# Should fail with 0
channel = self.make_request(
"POST",
self.url + "/new",
{"length": 0},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
channel = self.make_request(
"POST",
self.url + "/new",
{"length": -5},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a float
channel = self.make_request(
"POST",
self.url + "/new",
{"length": 8.5},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with 65
channel = self.make_request(
"POST",
self.url + "/new",
{"length": 65},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# UPDATING
def test_update_no_auth(self):
"""Try to update a token without authentication."""
channel = self.make_request(
"PUT",
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_update_requester_not_admin(self):
"""Try to update a token while not an admin."""
channel = self.make_request(
"PUT",
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
access_token=self.other_user_tok,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_update_non_existent(self):
"""Try to update a token that doesn't exist."""
channel = self.make_request(
"PUT",
self.url + "/1234",
{"uses_allowed": 1},
access_token=self.admin_user_tok,
)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_update_uses_allowed(self):
"""Test updating just uses_allowed."""
# Create new token using default values
token = self._new_token()
# Should succeed with 1
channel = self.make_request(
"PUT",
self.url + "/" + token,
{"uses_allowed": 1},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertIsNone(channel.json_body["expiry_time"])
# Should succeed with 0 (makes token invalid)
channel = self.make_request(
"PUT",
self.url + "/" + token,
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["uses_allowed"], 0)
self.assertIsNone(channel.json_body["expiry_time"])
# Should succeed with null
channel = self.make_request(
"PUT",
self.url + "/" + token,
{"uses_allowed": None},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
# Should fail with a float
channel = self.make_request(
"PUT",
self.url + "/" + token,
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
channel = self.make_request(
"PUT",
self.url + "/" + token,
{"uses_allowed": -5},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_expiry_time(self):
"""Test updating just expiry_time."""
# Create new token using default values
token = self._new_token()
new_expiry_time = self.clock.time_msec() + 1000000
# Should succeed with a time in the future
channel = self.make_request(
"PUT",
self.url + "/" + token,
{"expiry_time": new_expiry_time},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
self.assertIsNone(channel.json_body["uses_allowed"])
# Should succeed with null
channel = self.make_request(
"PUT",
self.url + "/" + token,
{"expiry_time": None},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertIsNone(channel.json_body["expiry_time"])
self.assertIsNone(channel.json_body["uses_allowed"])
# Should fail with a time in the past
past_time = self.clock.time_msec() - 10000
channel = self.make_request(
"PUT",
self.url + "/" + token,
{"expiry_time": past_time},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail a float
channel = self.make_request(
"PUT",
self.url + "/" + token,
{"expiry_time": new_expiry_time + 0.5},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_both(self):
"""Test updating both uses_allowed and expiry_time."""
# Create new token using default values
token = self._new_token()
new_expiry_time = self.clock.time_msec() + 1000000
data = {
"uses_allowed": 1,
"expiry_time": new_expiry_time,
}
channel = self.make_request(
"PUT",
self.url + "/" + token,
data,
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
def test_update_invalid_type(self):
"""Test using invalid types doesn't work."""
# Create new token using default values
token = self._new_token()
data = {
"uses_allowed": False,
"expiry_time": "1626430124000",
}
channel = self.make_request(
"PUT",
self.url + "/" + token,
data,
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# DELETING
def test_delete_no_auth(self):
"""Try to delete a token without authentication."""
channel = self.make_request(
"DELETE",
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_delete_requester_not_admin(self):
"""Try to delete a token while not an admin."""
channel = self.make_request(
"DELETE",
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
access_token=self.other_user_tok,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_delete_non_existent(self):
"""Try to delete a token that doesn't exist."""
channel = self.make_request(
"DELETE",
self.url + "/1234",
{},
access_token=self.admin_user_tok,
)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_delete(self):
"""Test deleting a token."""
# Create new token using default values
token = self._new_token()
channel = self.make_request(
"DELETE",
self.url + "/" + token,
{},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# GETTING ONE
def test_get_no_auth(self):
"""Try to get a token without authentication."""
channel = self.make_request(
"GET",
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_get_requester_not_admin(self):
"""Try to get a token while not an admin."""
channel = self.make_request(
"GET",
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
access_token=self.other_user_tok,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_get_non_existent(self):
"""Try to get a token that doesn't exist."""
channel = self.make_request(
"GET",
self.url + "/1234",
{},
access_token=self.admin_user_tok,
)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_get(self):
"""Test getting a token."""
# Create new token using default values
token = self._new_token()
channel = self.make_request(
"GET",
self.url + "/" + token,
{},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["token"], token)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
self.assertEqual(channel.json_body["pending"], 0)
self.assertEqual(channel.json_body["completed"], 0)
# LISTING
def test_list_no_auth(self):
"""Try to list tokens without authentication."""
channel = self.make_request("GET", self.url, {})
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_list_requester_not_admin(self):
"""Try to list tokens while not an admin."""
channel = self.make_request(
"GET",
self.url,
{},
access_token=self.other_user_tok,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_list_all(self):
"""Test listing all tokens."""
# Create new token using default values
token = self._new_token()
channel = self.make_request(
"GET",
self.url,
{},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
token_info = channel.json_body["registration_tokens"][0]
self.assertEqual(token_info["token"], token)
self.assertIsNone(token_info["uses_allowed"])
self.assertIsNone(token_info["expiry_time"])
self.assertEqual(token_info["pending"], 0)
self.assertEqual(token_info["completed"], 0)
def test_list_invalid_query_parameter(self):
"""Test with `valid` query parameter not `true` or `false`."""
channel = self.make_request(
"GET",
self.url + "?valid=x",
{},
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
def _test_list_query_parameter(self, valid: str):
"""Helper used to test both valid=true and valid=false."""
# Create 2 valid and 2 invalid tokens.
now = self.hs.get_clock().time_msec()
# Create always valid token
valid1 = self._new_token()
# Create token that hasn't been used up
valid2 = self._new_token(uses_allowed=1)
# Create token that has expired
invalid1 = self._new_token(expiry_time=now - 10000)
# Create token that has been used up but hasn't expired
invalid2 = self._new_token(
uses_allowed=2,
pending=1,
completed=1,
expiry_time=now + 1000000,
)
if valid == "true":
tokens = [valid1, valid2]
else:
tokens = [invalid1, invalid2]
channel = self.make_request(
"GET",
self.url + "?valid=" + valid,
{},
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
token_info_1 = channel.json_body["registration_tokens"][0]
token_info_2 = channel.json_body["registration_tokens"][1]
self.assertIn(token_info_1["token"], tokens)
self.assertIn(token_info_2["token"], tokens)
def test_list_valid(self):
"""Test listing just valid tokens."""
self._test_list_query_parameter(valid="true")
def test_list_invalid(self):
"""Test listing just invalid tokens."""
self._test_list_query_parameter(valid="false")

View File

@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync from synapse.rest.client import account, account_validity, login, logout, register, sync
from synapse.storage._base import db_to_json
from tests import unittest from tests import unittest
from tests.unittest import override_config from tests.unittest import override_config
@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config({"registration_requires_token": True})
def test_POST_registration_requires_token(self):
username = "kermit"
device_id = "frogfone"
token = "abcd"
store = self.hs.get_datastore()
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": None,
"pending": 0,
"completed": 0,
"expiry_time": None,
},
)
)
params = {
"username": username,
"password": "monkey",
"device_id": device_id,
}
# Request without auth to get flows and session
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
# Synapse adds a dummy stage to differentiate flows where otherwise one
# flow would be a subset of another flow.
self.assertCountEqual(
[[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
(f["stages"] for f in flows),
)
session = channel.json_body["session"]
# Do the registration token stage and check it has completed
params["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session,
}
request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"401", channel.result)
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
# Do the m.login.dummy stage and check registration was successful
params["auth"] = {
"type": LoginType.DUMMY,
"session": session,
}
request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
det_data = {
"user_id": f"@{username}:{self.hs.hostname}",
"home_server": self.hs.hostname,
"device_id": device_id,
}
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
# Check the `completed` counter has been incremented and pending is 0
res = self.get_success(
store.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
)
)
self.assertEquals(res["completed"], 1)
self.assertEquals(res["pending"], 0)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_invalid(self):
params = {
"username": "kermit",
"password": "monkey",
}
# Request without auth to get session
channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"]
# Test with token param missing (invalid)
params["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"session": session,
}
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
self.assertEquals(channel.json_body["completed"], [])
# Test with non-string (invalid)
params["auth"]["token"] = 1234
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
self.assertEquals(channel.json_body["completed"], [])
# Test with unknown token (invalid)
params["auth"]["token"] = "1234"
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEquals(channel.json_body["completed"], [])
@override_config({"registration_requires_token": True})
def test_POST_registration_token_limit_uses(self):
token = "abcd"
store = self.hs.get_datastore()
# Create token that can be used once
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": 1,
"pending": 0,
"completed": 0,
"expiry_time": None,
},
)
)
params1 = {"username": "bert", "password": "monkey"}
params2 = {"username": "ernie", "password": "monkey"}
# Do 2 requests without auth to get two session IDs
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
session1 = channel1.json_body["session"]
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
session2 = channel2.json_body["session"]
# Use token with session1 and check `pending` is 1
params1["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session1,
}
self.make_request(b"POST", self.url, json.dumps(params1))
# Repeat request to make sure pending isn't increased again
self.make_request(b"POST", self.url, json.dumps(params1))
pending = self.get_success(
store.db_pool.simple_select_one_onecol(
"registration_tokens",
keyvalues={"token": token},
retcol="pending",
)
)
self.assertEquals(pending, 1)
# Check auth fails when using token with session2
params2["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session2,
}
channel = self.make_request(b"POST", self.url, json.dumps(params2))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEquals(channel.json_body["completed"], [])
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
self.make_request(b"POST", self.url, json.dumps(params1))
# Check pending=0 and completed=1
res = self.get_success(
store.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
)
)
self.assertEquals(res["pending"], 0)
self.assertEquals(res["completed"], 1)
# Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, json.dumps(params2))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEquals(channel.json_body["completed"], [])
@override_config({"registration_requires_token": True})
def test_POST_registration_token_expiry(self):
token = "abcd"
now = self.hs.get_clock().time_msec()
store = self.hs.get_datastore()
# Create token that expired yesterday
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": None,
"pending": 0,
"completed": 0,
"expiry_time": now - 24 * 60 * 60 * 1000,
},
)
)
params = {"username": "kermit", "password": "monkey"}
# Request without auth to get session
channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"]
# Check authentication fails with expired token
params["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session,
}
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEquals(channel.json_body["completed"], [])
# Update token so it expires tomorrow
self.get_success(
store.db_pool.simple_update_one(
"registration_tokens",
keyvalues={"token": token},
updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
)
)
# Check authentication succeeds
channel = self.make_request(b"POST", self.url, json.dumps(params))
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_session_expiry(self):
"""Test `pending` is decremented when an uncompleted session expires."""
token = "abcd"
store = self.hs.get_datastore()
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": None,
"pending": 0,
"completed": 0,
"expiry_time": None,
},
)
)
# Do 2 requests without auth to get two session IDs
params1 = {"username": "bert", "password": "monkey"}
params2 = {"username": "ernie", "password": "monkey"}
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
session1 = channel1.json_body["session"]
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
session2 = channel2.json_body["session"]
# Use token with both sessions
params1["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session1,
}
self.make_request(b"POST", self.url, json.dumps(params1))
params2["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session2,
}
self.make_request(b"POST", self.url, json.dumps(params2))
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
self.make_request(b"POST", self.url, json.dumps(params1))
# Check `result` of registration token stage for session1 is `True`
result1 = self.get_success(
store.db_pool.simple_select_one_onecol(
"ui_auth_sessions_credentials",
keyvalues={
"session_id": session1,
"stage_type": LoginType.REGISTRATION_TOKEN,
},
retcol="result",
)
)
self.assertTrue(db_to_json(result1))
# Check `result` for session2 is the token used
result2 = self.get_success(
store.db_pool.simple_select_one_onecol(
"ui_auth_sessions_credentials",
keyvalues={
"session_id": session2,
"stage_type": LoginType.REGISTRATION_TOKEN,
},
retcol="result",
)
)
self.assertEquals(db_to_json(result2), token)
# Delete both sessions (mimics expiry)
self.get_success(
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
)
# Check pending is now 0
pending = self.get_success(
store.db_pool.simple_select_one_onecol(
"registration_tokens",
keyvalues={"token": token},
retcol="pending",
)
)
self.assertEquals(pending, 0)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_session_expiry_deleted_token(self):
"""Test session expiry doesn't break when the token is deleted.
1. Start but don't complete UIA with a registration token
2. Delete the token from the database
3. Expire the session
"""
token = "abcd"
store = self.hs.get_datastore()
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": None,
"pending": 0,
"completed": 0,
"expiry_time": None,
},
)
)
# Do request without auth to get a session ID
params = {"username": "kermit", "password": "monkey"}
channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"]
# Use token
params["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session,
}
self.make_request(b"POST", self.url, json.dumps(params))
# Delete token
self.get_success(
store.db_pool.simple_delete_one(
"registration_tokens",
keyvalues={"token": token},
)
)
# Delete session (mimics expiry)
self.get_success(
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
)
def test_advertised_flows(self): def test_advertised_flows(self):
channel = self.make_request(b"POST", self.url, b"{}") channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result) self.assertEquals(channel.result["code"], b"401", channel.result)
@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period) self.assertLessEqual(res, now_ms + self.validity_period)
class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets]
url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
def default_config(self):
config = super().default_config()
config["registration_requires_token"] = True
return config
def test_GET_token_valid(self):
token = "abcd"
store = self.hs.get_datastore()
self.get_success(
store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
"uses_allowed": None,
"pending": 0,
"completed": 0,
"expiry_time": None,
},
)
)
channel = self.make_request(
b"GET",
f"{self.url}?token={token}",
)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEquals(channel.json_body["valid"], True)
def test_GET_token_invalid(self):
token = "1234"
channel = self.make_request(
b"GET",
f"{self.url}?token={token}",
)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEquals(channel.json_body["valid"], False)
@override_config(
{"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
)
def test_GET_ratelimiting(self):
token = "1234"
for i in range(0, 6):
channel = self.make_request(
b"GET",
f"{self.url}?token={token}",
)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(
b"GET",
f"{self.url}?token={token}",
)
self.assertEquals(channel.result["code"], b"200", channel.result)