mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-16 09:17:10 -05:00
Implement MSC3231: Token authenticated registration (#10142)
Signed-off-by: Callum Brown <callum@calcuode.com> This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231).
This commit is contained in:
parent
ecd823d766
commit
947dbbdfd1
1
changelog.d/10142.feature
Normal file
1
changelog.d/10142.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add support for [MSC3231 - Token authenticated registration](https://github.com/matrix-org/matrix-doc/pull/3231). Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown.
|
@ -53,6 +53,7 @@
|
|||||||
- [Media](admin_api/media_admin_api.md)
|
- [Media](admin_api/media_admin_api.md)
|
||||||
- [Purge History](admin_api/purge_history_api.md)
|
- [Purge History](admin_api/purge_history_api.md)
|
||||||
- [Register Users](admin_api/register_api.md)
|
- [Register Users](admin_api/register_api.md)
|
||||||
|
- [Registration Tokens](usage/administration/admin_api/registration_tokens.md)
|
||||||
- [Manipulate Room Membership](admin_api/room_membership.md)
|
- [Manipulate Room Membership](admin_api/room_membership.md)
|
||||||
- [Rooms](admin_api/rooms.md)
|
- [Rooms](admin_api/rooms.md)
|
||||||
- [Server Notices](admin_api/server_notices.md)
|
- [Server Notices](admin_api/server_notices.md)
|
||||||
|
@ -793,6 +793,8 @@ log_config: "CONFDIR/SERVERNAME.log.config"
|
|||||||
# is using
|
# is using
|
||||||
# - one for registration that ratelimits registration requests based on the
|
# - one for registration that ratelimits registration requests based on the
|
||||||
# client's IP address.
|
# client's IP address.
|
||||||
|
# - one for checking the validity of registration tokens that ratelimits
|
||||||
|
# requests based on the client's IP address.
|
||||||
# - one for login that ratelimits login requests based on the client's IP
|
# - one for login that ratelimits login requests based on the client's IP
|
||||||
# address.
|
# address.
|
||||||
# - one for login that ratelimits login requests based on the account the
|
# - one for login that ratelimits login requests based on the account the
|
||||||
@ -821,6 +823,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
|
|||||||
# per_second: 0.17
|
# per_second: 0.17
|
||||||
# burst_count: 3
|
# burst_count: 3
|
||||||
#
|
#
|
||||||
|
#rc_registration_token_validity:
|
||||||
|
# per_second: 0.1
|
||||||
|
# burst_count: 5
|
||||||
|
#
|
||||||
#rc_login:
|
#rc_login:
|
||||||
# address:
|
# address:
|
||||||
# per_second: 0.17
|
# per_second: 0.17
|
||||||
@ -1169,6 +1175,15 @@ url_preview_accept_language:
|
|||||||
#
|
#
|
||||||
#enable_3pid_lookup: true
|
#enable_3pid_lookup: true
|
||||||
|
|
||||||
|
# Require users to submit a token during registration.
|
||||||
|
# Tokens can be managed using the admin API:
|
||||||
|
# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
|
||||||
|
# Note that `enable_registration` must be set to `true`.
|
||||||
|
# Disabling this option will not delete any tokens previously generated.
|
||||||
|
# Defaults to false. Uncomment the following to require tokens:
|
||||||
|
#
|
||||||
|
#registration_requires_token: true
|
||||||
|
|
||||||
# If set, allows registration of standard or admin accounts by anyone who
|
# If set, allows registration of standard or admin accounts by anyone who
|
||||||
# has the shared secret, even if registration is otherwise disabled.
|
# has the shared secret, even if registration is otherwise disabled.
|
||||||
#
|
#
|
||||||
|
295
docs/usage/administration/admin_api/registration_tokens.md
Normal file
295
docs/usage/administration/admin_api/registration_tokens.md
Normal file
@ -0,0 +1,295 @@
|
|||||||
|
# Registration Tokens
|
||||||
|
|
||||||
|
This API allows you to manage tokens which can be used to authenticate
|
||||||
|
registration requests, as proposed in [MSC3231](https://github.com/govynnus/matrix-doc/blob/token-registration/proposals/3231-token-authenticated-registration.md).
|
||||||
|
To use it, you will need to enable the `registration_requires_token` config
|
||||||
|
option, and authenticate by providing an `access_token` for a server admin:
|
||||||
|
see [Admin API](../../usage/administration/admin_api).
|
||||||
|
Note that this API is still experimental; not all clients may support it yet.
|
||||||
|
|
||||||
|
|
||||||
|
## Registration token objects
|
||||||
|
|
||||||
|
Most endpoints make use of JSON objects that contain details about tokens.
|
||||||
|
These objects have the following fields:
|
||||||
|
- `token`: The token which can be used to authenticate registration.
|
||||||
|
- `uses_allowed`: The number of times the token can be used to complete a
|
||||||
|
registration before it becomes invalid.
|
||||||
|
- `pending`: The number of pending uses the token has. When someone uses
|
||||||
|
the token to authenticate themselves, the pending counter is incremented
|
||||||
|
so that the token is not used more than the permitted number of times.
|
||||||
|
When the person completes registration the pending counter is decremented,
|
||||||
|
and the completed counter is incremented.
|
||||||
|
- `completed`: The number of times the token has been used to successfully
|
||||||
|
complete a registration.
|
||||||
|
- `expiry_time`: The latest time the token is valid. Given as the number of
|
||||||
|
milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
|
||||||
|
To convert this into a human-readable form you can remove the milliseconds
|
||||||
|
and use the `date` command. For example, `date -d '@1625394937'`.
|
||||||
|
|
||||||
|
|
||||||
|
## List all tokens
|
||||||
|
|
||||||
|
Lists all tokens and details about them. If the request is successful, the top
|
||||||
|
level JSON object will have a `registration_tokens` key which is an array of
|
||||||
|
registration token objects.
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /_synapse/admin/v1/registration_tokens
|
||||||
|
```
|
||||||
|
|
||||||
|
Optional query parameters:
|
||||||
|
- `valid`: `true` or `false`. If `true`, only valid tokens are returned.
|
||||||
|
If `false`, only tokens that have expired or have had all uses exhausted are
|
||||||
|
returned. If omitted, all tokens are returned regardless of validity.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /_synapse/admin/v1/registration_tokens
|
||||||
|
```
|
||||||
|
```
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{
|
||||||
|
"registration_tokens": [
|
||||||
|
{
|
||||||
|
"token": "abcd",
|
||||||
|
"uses_allowed": 3,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 1,
|
||||||
|
"expiry_time": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"token": "pqrs",
|
||||||
|
"uses_allowed": 2,
|
||||||
|
"pending": 1,
|
||||||
|
"completed": 1,
|
||||||
|
"expiry_time": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"token": "wxyz",
|
||||||
|
"uses_allowed": null,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 9,
|
||||||
|
"expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Example using the `valid` query parameter:
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /_synapse/admin/v1/registration_tokens?valid=false
|
||||||
|
```
|
||||||
|
```
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{
|
||||||
|
"registration_tokens": [
|
||||||
|
{
|
||||||
|
"token": "pqrs",
|
||||||
|
"uses_allowed": 2,
|
||||||
|
"pending": 1,
|
||||||
|
"completed": 1,
|
||||||
|
"expiry_time": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"token": "wxyz",
|
||||||
|
"uses_allowed": null,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 9,
|
||||||
|
"expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Get one token
|
||||||
|
|
||||||
|
Get details about a single token. If the request is successful, the response
|
||||||
|
body will be a registration token object.
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /_synapse/admin/v1/registration_tokens/<token>
|
||||||
|
```
|
||||||
|
|
||||||
|
Path parameters:
|
||||||
|
- `token`: The registration token to return details of.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /_synapse/admin/v1/registration_tokens/abcd
|
||||||
|
```
|
||||||
|
```
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{
|
||||||
|
"token": "abcd",
|
||||||
|
"uses_allowed": 3,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 1,
|
||||||
|
"expiry_time": null
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Create token
|
||||||
|
|
||||||
|
Create a new registration token. If the request is successful, the newly created
|
||||||
|
token will be returned as a registration token object in the response body.
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /_synapse/admin/v1/registration_tokens/new
|
||||||
|
```
|
||||||
|
|
||||||
|
The request body must be a JSON object and can contain the following fields:
|
||||||
|
- `token`: The registration token. A string of no more than 64 characters that
|
||||||
|
consists only of characters matched by the regex `[A-Za-z0-9-_]`.
|
||||||
|
Default: randomly generated.
|
||||||
|
- `uses_allowed`: The integer number of times the token can be used to complete
|
||||||
|
a registration before it becomes invalid.
|
||||||
|
Default: `null` (unlimited uses).
|
||||||
|
- `expiry_time`: The latest time the token is valid. Given as the number of
|
||||||
|
milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
|
||||||
|
You could use, for example, `date '+%s000' -d 'tomorrow'`.
|
||||||
|
Default: `null` (token does not expire).
|
||||||
|
- `length`: The length of the token randomly generated if `token` is not
|
||||||
|
specified. Must be between 1 and 64 inclusive. Default: `16`.
|
||||||
|
|
||||||
|
If a field is omitted the default is used.
|
||||||
|
|
||||||
|
Example using defaults:
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /_synapse/admin/v1/registration_tokens/new
|
||||||
|
|
||||||
|
{}
|
||||||
|
```
|
||||||
|
```
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{
|
||||||
|
"token": "0M-9jbkf2t_Tgiw1",
|
||||||
|
"uses_allowed": null,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": null
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Example specifying some fields:
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /_synapse/admin/v1/registration_tokens/new
|
||||||
|
|
||||||
|
{
|
||||||
|
"token": "defg",
|
||||||
|
"uses_allowed": 1
|
||||||
|
}
|
||||||
|
```
|
||||||
|
```
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{
|
||||||
|
"token": "defg",
|
||||||
|
"uses_allowed": 1,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": null
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Update token
|
||||||
|
|
||||||
|
Update the number of allowed uses or expiry time of a token. If the request is
|
||||||
|
successful, the updated token will be returned as a registration token object
|
||||||
|
in the response body.
|
||||||
|
|
||||||
|
```
|
||||||
|
PUT /_synapse/admin/v1/registration_tokens/<token>
|
||||||
|
```
|
||||||
|
|
||||||
|
Path parameters:
|
||||||
|
- `token`: The registration token to update.
|
||||||
|
|
||||||
|
The request body must be a JSON object and can contain the following fields:
|
||||||
|
- `uses_allowed`: The integer number of times the token can be used to complete
|
||||||
|
a registration before it becomes invalid. By setting `uses_allowed` to `0`
|
||||||
|
the token can be easily made invalid without deleting it.
|
||||||
|
If `null` the token will have an unlimited number of uses.
|
||||||
|
- `expiry_time`: The latest time the token is valid. Given as the number of
|
||||||
|
milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
|
||||||
|
If `null` the token will not expire.
|
||||||
|
|
||||||
|
If a field is omitted its value is not modified.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```
|
||||||
|
PUT /_synapse/admin/v1/registration_tokens/defg
|
||||||
|
|
||||||
|
{
|
||||||
|
"expiry_time": 4781243146000 // 2121-07-06 11:05:46 UTC
|
||||||
|
}
|
||||||
|
```
|
||||||
|
```
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{
|
||||||
|
"token": "defg",
|
||||||
|
"uses_allowed": 1,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": 4781243146000
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Delete token
|
||||||
|
|
||||||
|
Delete a registration token. If the request is successful, the response body
|
||||||
|
will be an empty JSON object.
|
||||||
|
|
||||||
|
```
|
||||||
|
DELETE /_synapse/admin/v1/registration_tokens/<token>
|
||||||
|
```
|
||||||
|
|
||||||
|
Path parameters:
|
||||||
|
- `token`: The registration token to delete.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```
|
||||||
|
DELETE /_synapse/admin/v1/registration_tokens/wxyz
|
||||||
|
```
|
||||||
|
```
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Errors
|
||||||
|
|
||||||
|
If a request fails a "standard error response" will be returned as defined in
|
||||||
|
the [Matrix Client-Server API specification](https://matrix.org/docs/spec/client_server/r0.6.1#api-standards).
|
||||||
|
|
||||||
|
For example, if the token specified in a path parameter does not exist a
|
||||||
|
`404 Not Found` error will be returned.
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /_synapse/admin/v1/registration_tokens/1234
|
||||||
|
```
|
||||||
|
```
|
||||||
|
404 Not Found
|
||||||
|
|
||||||
|
{
|
||||||
|
"errcode": "M_NOT_FOUND",
|
||||||
|
"error": "No such registration token: 1234"
|
||||||
|
}
|
||||||
|
```
|
@ -236,6 +236,7 @@ expressions:
|
|||||||
# Registration/login requests
|
# Registration/login requests
|
||||||
^/_matrix/client/(api/v1|r0|unstable)/login$
|
^/_matrix/client/(api/v1|r0|unstable)/login$
|
||||||
^/_matrix/client/(r0|unstable)/register$
|
^/_matrix/client/(r0|unstable)/register$
|
||||||
|
^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$
|
||||||
|
|
||||||
# Event sending requests
|
# Event sending requests
|
||||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact
|
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact
|
||||||
|
@ -79,6 +79,7 @@ class LoginType:
|
|||||||
TERMS = "m.login.terms"
|
TERMS = "m.login.terms"
|
||||||
SSO = "m.login.sso"
|
SSO = "m.login.sso"
|
||||||
DUMMY = "m.login.dummy"
|
DUMMY = "m.login.dummy"
|
||||||
|
REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
|
||||||
|
|
||||||
|
|
||||||
# This is used in the `type` parameter for /register when called by
|
# This is used in the `type` parameter for /register when called by
|
||||||
|
@ -95,7 +95,10 @@ from synapse.rest.client.profile import (
|
|||||||
ProfileRestServlet,
|
ProfileRestServlet,
|
||||||
)
|
)
|
||||||
from synapse.rest.client.push_rule import PushRuleRestServlet
|
from synapse.rest.client.push_rule import PushRuleRestServlet
|
||||||
from synapse.rest.client.register import RegisterRestServlet
|
from synapse.rest.client.register import (
|
||||||
|
RegisterRestServlet,
|
||||||
|
RegistrationTokenValidityRestServlet,
|
||||||
|
)
|
||||||
from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
|
from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
|
||||||
from synapse.rest.client.versions import VersionsRestServlet
|
from synapse.rest.client.versions import VersionsRestServlet
|
||||||
from synapse.rest.client.voip import VoipRestServlet
|
from synapse.rest.client.voip import VoipRestServlet
|
||||||
@ -279,6 +282,7 @@ class GenericWorkerServer(HomeServer):
|
|||||||
resource = JsonResource(self, canonical_json=False)
|
resource = JsonResource(self, canonical_json=False)
|
||||||
|
|
||||||
RegisterRestServlet(self).register(resource)
|
RegisterRestServlet(self).register(resource)
|
||||||
|
RegistrationTokenValidityRestServlet(self).register(resource)
|
||||||
login.register_servlets(self, resource)
|
login.register_servlets(self, resource)
|
||||||
ThreepidRestServlet(self).register(resource)
|
ThreepidRestServlet(self).register(resource)
|
||||||
DevicesRestServlet(self).register(resource)
|
DevicesRestServlet(self).register(resource)
|
||||||
|
@ -79,6 +79,11 @@ class RatelimitConfig(Config):
|
|||||||
|
|
||||||
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
|
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
|
||||||
|
|
||||||
|
self.rc_registration_token_validity = RateLimitConfig(
|
||||||
|
config.get("rc_registration_token_validity", {}),
|
||||||
|
defaults={"per_second": 0.1, "burst_count": 5},
|
||||||
|
)
|
||||||
|
|
||||||
rc_login_config = config.get("rc_login", {})
|
rc_login_config = config.get("rc_login", {})
|
||||||
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
|
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
|
||||||
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
|
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
|
||||||
@ -143,6 +148,8 @@ class RatelimitConfig(Config):
|
|||||||
# is using
|
# is using
|
||||||
# - one for registration that ratelimits registration requests based on the
|
# - one for registration that ratelimits registration requests based on the
|
||||||
# client's IP address.
|
# client's IP address.
|
||||||
|
# - one for checking the validity of registration tokens that ratelimits
|
||||||
|
# requests based on the client's IP address.
|
||||||
# - one for login that ratelimits login requests based on the client's IP
|
# - one for login that ratelimits login requests based on the client's IP
|
||||||
# address.
|
# address.
|
||||||
# - one for login that ratelimits login requests based on the account the
|
# - one for login that ratelimits login requests based on the account the
|
||||||
@ -171,6 +178,10 @@ class RatelimitConfig(Config):
|
|||||||
# per_second: 0.17
|
# per_second: 0.17
|
||||||
# burst_count: 3
|
# burst_count: 3
|
||||||
#
|
#
|
||||||
|
#rc_registration_token_validity:
|
||||||
|
# per_second: 0.1
|
||||||
|
# burst_count: 5
|
||||||
|
#
|
||||||
#rc_login:
|
#rc_login:
|
||||||
# address:
|
# address:
|
||||||
# per_second: 0.17
|
# per_second: 0.17
|
||||||
|
@ -33,6 +33,9 @@ class RegistrationConfig(Config):
|
|||||||
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
||||||
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
||||||
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
|
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
|
||||||
|
self.registration_requires_token = config.get(
|
||||||
|
"registration_requires_token", False
|
||||||
|
)
|
||||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||||
|
|
||||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||||
@ -140,6 +143,9 @@ class RegistrationConfig(Config):
|
|||||||
"mechanism by removing the `access_token_lifetime` option."
|
"mechanism by removing the `access_token_lifetime` option."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The fallback template used for authenticating using a registration token
|
||||||
|
self.registration_token_template = self.read_template("registration_token.html")
|
||||||
|
|
||||||
# The success template used during fallback auth.
|
# The success template used during fallback auth.
|
||||||
self.fallback_success_template = self.read_template("auth_success.html")
|
self.fallback_success_template = self.read_template("auth_success.html")
|
||||||
|
|
||||||
@ -199,6 +205,15 @@ class RegistrationConfig(Config):
|
|||||||
#
|
#
|
||||||
#enable_3pid_lookup: true
|
#enable_3pid_lookup: true
|
||||||
|
|
||||||
|
# Require users to submit a token during registration.
|
||||||
|
# Tokens can be managed using the admin API:
|
||||||
|
# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
|
||||||
|
# Note that `enable_registration` must be set to `true`.
|
||||||
|
# Disabling this option will not delete any tokens previously generated.
|
||||||
|
# Defaults to false. Uncomment the following to require tokens:
|
||||||
|
#
|
||||||
|
#registration_requires_token: true
|
||||||
|
|
||||||
# If set, allows registration of standard or admin accounts by anyone who
|
# If set, allows registration of standard or admin accounts by anyone who
|
||||||
# has the shared secret, even if registration is otherwise disabled.
|
# has the shared secret, even if registration is otherwise disabled.
|
||||||
#
|
#
|
||||||
|
@ -34,3 +34,8 @@ class UIAuthSessionDataConstants:
|
|||||||
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
|
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
|
||||||
# for.
|
# for.
|
||||||
REQUEST_USER_ID = "request_user_id"
|
REQUEST_USER_ID = "request_user_id"
|
||||||
|
|
||||||
|
# used during registration to store the registration token used (if required) so that:
|
||||||
|
# - we can prevent a token being used twice by one session
|
||||||
|
# - we can 'use up' the token after registration has successfully completed
|
||||||
|
REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
|
||||||
|
@ -241,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
|
|||||||
return await self._check_threepid("msisdn", authdict)
|
return await self._check_threepid("msisdn", authdict)
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
|
||||||
|
AUTH_TYPE = LoginType.REGISTRATION_TOKEN
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__(hs)
|
||||||
|
self.hs = hs
|
||||||
|
self._enabled = bool(hs.config.registration_requires_token)
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
def is_enabled(self) -> bool:
|
||||||
|
return self._enabled
|
||||||
|
|
||||||
|
async def check_auth(self, authdict: dict, clientip: str) -> Any:
|
||||||
|
if "token" not in authdict:
|
||||||
|
raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
|
||||||
|
if not isinstance(authdict["token"], str):
|
||||||
|
raise LoginError(
|
||||||
|
400, "Registration token must be a string", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
if "session" not in authdict:
|
||||||
|
raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
# Get these here to avoid cyclic dependencies
|
||||||
|
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
|
||||||
|
|
||||||
|
auth_handler = self.hs.get_auth_handler()
|
||||||
|
|
||||||
|
session = authdict["session"]
|
||||||
|
token = authdict["token"]
|
||||||
|
|
||||||
|
# If the LoginType.REGISTRATION_TOKEN stage has already been completed,
|
||||||
|
# return early to avoid incrementing `pending` again.
|
||||||
|
stored_token = await auth_handler.get_session_data(
|
||||||
|
session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
|
||||||
|
)
|
||||||
|
if stored_token:
|
||||||
|
if token != stored_token:
|
||||||
|
raise LoginError(
|
||||||
|
400, "Registration token has changed", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return token
|
||||||
|
|
||||||
|
if await self.store.registration_token_is_valid(token):
|
||||||
|
# Increment pending counter, so that if token has limited uses it
|
||||||
|
# can't be used up by someone else in the meantime.
|
||||||
|
await self.store.set_registration_token_pending(token)
|
||||||
|
# Store the token in the UIA session, so that once registration
|
||||||
|
# is complete `completed` can be incremented.
|
||||||
|
await auth_handler.set_session_data(
|
||||||
|
session,
|
||||||
|
UIAuthSessionDataConstants.REGISTRATION_TOKEN,
|
||||||
|
token,
|
||||||
|
)
|
||||||
|
# The token will be stored as the result of the authentication stage
|
||||||
|
# in ui_auth_sessions_credentials. This allows the pending counter
|
||||||
|
# for tokens to be decremented when expired sessions are deleted.
|
||||||
|
return token
|
||||||
|
else:
|
||||||
|
raise LoginError(
|
||||||
|
401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
INTERACTIVE_AUTH_CHECKERS = [
|
INTERACTIVE_AUTH_CHECKERS = [
|
||||||
DummyAuthChecker,
|
DummyAuthChecker,
|
||||||
TermsAuthChecker,
|
TermsAuthChecker,
|
||||||
RecaptchaAuthChecker,
|
RecaptchaAuthChecker,
|
||||||
EmailIdentityAuthChecker,
|
EmailIdentityAuthChecker,
|
||||||
MsisdnAuthChecker,
|
MsisdnAuthChecker,
|
||||||
|
RegistrationTokenAuthChecker,
|
||||||
]
|
]
|
||||||
"""A list of UserInteractiveAuthChecker classes"""
|
"""A list of UserInteractiveAuthChecker classes"""
|
||||||
|
23
synapse/res/templates/registration_token.html
Normal file
23
synapse/res/templates/registration_token.html
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Authentication</title>
|
||||||
|
<meta name='viewport' content='width=device-width, initial-scale=1,
|
||||||
|
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
|
||||||
|
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<form id="registrationForm" method="post" action="{{ myurl }}">
|
||||||
|
<div>
|
||||||
|
{% if error is defined %}
|
||||||
|
<p class="error"><strong>Error: {{ error }}</strong></p>
|
||||||
|
{% endif %}
|
||||||
|
<p>
|
||||||
|
Please enter a registration token.
|
||||||
|
</p>
|
||||||
|
<input type="hidden" name="session" value="{{ session }}" />
|
||||||
|
<input type="text" name="token" />
|
||||||
|
<input type="submit" value="Authenticate" />
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>
|
@ -36,6 +36,11 @@ from synapse.rest.admin.event_reports import (
|
|||||||
)
|
)
|
||||||
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
|
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
|
||||||
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
|
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
|
||||||
|
from synapse.rest.admin.registration_tokens import (
|
||||||
|
ListRegistrationTokensRestServlet,
|
||||||
|
NewRegistrationTokenRestServlet,
|
||||||
|
RegistrationTokenRestServlet,
|
||||||
|
)
|
||||||
from synapse.rest.admin.rooms import (
|
from synapse.rest.admin.rooms import (
|
||||||
DeleteRoomRestServlet,
|
DeleteRoomRestServlet,
|
||||||
ForwardExtremitiesRestServlet,
|
ForwardExtremitiesRestServlet,
|
||||||
@ -238,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
|||||||
RoomEventContextServlet(hs).register(http_server)
|
RoomEventContextServlet(hs).register(http_server)
|
||||||
RateLimitRestServlet(hs).register(http_server)
|
RateLimitRestServlet(hs).register(http_server)
|
||||||
UsernameAvailableRestServlet(hs).register(http_server)
|
UsernameAvailableRestServlet(hs).register(http_server)
|
||||||
|
ListRegistrationTokensRestServlet(hs).register(http_server)
|
||||||
|
NewRegistrationTokenRestServlet(hs).register(http_server)
|
||||||
|
RegistrationTokenRestServlet(hs).register(http_server)
|
||||||
|
|
||||||
|
|
||||||
def register_servlets_for_client_rest_resource(
|
def register_servlets_for_client_rest_resource(
|
||||||
|
321
synapse/rest/admin/registration_tokens.py
Normal file
321
synapse/rest/admin/registration_tokens.py
Normal file
@ -0,0 +1,321 @@
|
|||||||
|
# Copyright 2021 Callum Brown
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import string
|
||||||
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
|
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||||
|
from synapse.http.servlet import (
|
||||||
|
RestServlet,
|
||||||
|
parse_boolean,
|
||||||
|
parse_json_object_from_request,
|
||||||
|
)
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
|
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ListRegistrationTokensRestServlet(RestServlet):
|
||||||
|
"""List registration tokens.
|
||||||
|
|
||||||
|
To list all tokens:
|
||||||
|
|
||||||
|
GET /_synapse/admin/v1/registration_tokens
|
||||||
|
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{
|
||||||
|
"registration_tokens": [
|
||||||
|
{
|
||||||
|
"token": "abcd",
|
||||||
|
"uses_allowed": 3,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 1,
|
||||||
|
"expiry_time": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"token": "wxyz",
|
||||||
|
"uses_allowed": null,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 9,
|
||||||
|
"expiry_time": 1625394937000
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
The optional query parameter `valid` can be used to filter the response.
|
||||||
|
If it is `true`, only valid tokens are returned. If it is `false`, only
|
||||||
|
tokens that have expired or have had all uses exhausted are returned.
|
||||||
|
If it is omitted, all tokens are returned regardless of validity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = admin_patterns("/registration_tokens$")
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self.hs = hs
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
valid = parse_boolean(request, "valid")
|
||||||
|
token_list = await self.store.get_registration_tokens(valid)
|
||||||
|
return 200, {"registration_tokens": token_list}
|
||||||
|
|
||||||
|
|
||||||
|
class NewRegistrationTokenRestServlet(RestServlet):
|
||||||
|
"""Create a new registration token.
|
||||||
|
|
||||||
|
For example, to create a token specifying some fields:
|
||||||
|
|
||||||
|
POST /_synapse/admin/v1/registration_tokens/new
|
||||||
|
|
||||||
|
{
|
||||||
|
"token": "defg",
|
||||||
|
"uses_allowed": 1
|
||||||
|
}
|
||||||
|
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{
|
||||||
|
"token": "defg",
|
||||||
|
"uses_allowed": 1,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": null
|
||||||
|
}
|
||||||
|
|
||||||
|
Defaults are used for any fields not specified.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = admin_patterns("/registration_tokens/new$")
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self.hs = hs
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
# A string of all the characters allowed to be in a registration_token
|
||||||
|
self.allowed_chars = string.ascii_letters + string.digits + "-_"
|
||||||
|
self.allowed_chars_set = set(self.allowed_chars)
|
||||||
|
|
||||||
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
if "token" in body:
|
||||||
|
token = body["token"]
|
||||||
|
if not isinstance(token, str):
|
||||||
|
raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
|
||||||
|
if not (0 < len(token) <= 64):
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"token must not be empty and must not be longer than 64 characters",
|
||||||
|
Codes.INVALID_PARAM,
|
||||||
|
)
|
||||||
|
if not set(token).issubset(self.allowed_chars_set):
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"token must consist only of characters matched by the regex [A-Za-z0-9-_]",
|
||||||
|
Codes.INVALID_PARAM,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Get length of token to generate (default is 16)
|
||||||
|
length = body.get("length", 16)
|
||||||
|
if not isinstance(length, int):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "length must be an integer", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
if not (0 < length <= 64):
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"length must be greater than zero and not greater than 64",
|
||||||
|
Codes.INVALID_PARAM,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate token
|
||||||
|
token = await self.store.generate_registration_token(
|
||||||
|
length, self.allowed_chars
|
||||||
|
)
|
||||||
|
|
||||||
|
uses_allowed = body.get("uses_allowed", None)
|
||||||
|
if not (
|
||||||
|
uses_allowed is None
|
||||||
|
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
|
||||||
|
):
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"uses_allowed must be a non-negative integer or null",
|
||||||
|
Codes.INVALID_PARAM,
|
||||||
|
)
|
||||||
|
|
||||||
|
expiry_time = body.get("expiry_time", None)
|
||||||
|
if not isinstance(expiry_time, (int, type(None))):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
|
||||||
|
raise SynapseError(
|
||||||
|
400, "expiry_time must not be in the past", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
|
||||||
|
created = await self.store.create_registration_token(
|
||||||
|
token, uses_allowed, expiry_time
|
||||||
|
)
|
||||||
|
if not created:
|
||||||
|
raise SynapseError(
|
||||||
|
400, f"Token already exists: {token}", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"token": token,
|
||||||
|
"uses_allowed": uses_allowed,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": expiry_time,
|
||||||
|
}
|
||||||
|
return 200, resp
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationTokenRestServlet(RestServlet):
|
||||||
|
"""Retrieve, update, or delete the given token.
|
||||||
|
|
||||||
|
For example,
|
||||||
|
|
||||||
|
to retrieve a token:
|
||||||
|
|
||||||
|
GET /_synapse/admin/v1/registration_tokens/abcd
|
||||||
|
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{
|
||||||
|
"token": "abcd",
|
||||||
|
"uses_allowed": 3,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 1,
|
||||||
|
"expiry_time": null
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
to update a token:
|
||||||
|
|
||||||
|
PUT /_synapse/admin/v1/registration_tokens/defg
|
||||||
|
|
||||||
|
{
|
||||||
|
"uses_allowed": 5,
|
||||||
|
"expiry_time": 4781243146000
|
||||||
|
}
|
||||||
|
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{
|
||||||
|
"token": "defg",
|
||||||
|
"uses_allowed": 5,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": 4781243146000
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
to delete a token:
|
||||||
|
|
||||||
|
DELETE /_synapse/admin/v1/registration_tokens/wxyz
|
||||||
|
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{}
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self.hs = hs
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
|
||||||
|
"""Retrieve a registration token."""
|
||||||
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
token_info = await self.store.get_one_registration_token(token)
|
||||||
|
|
||||||
|
# If no result return a 404
|
||||||
|
if token_info is None:
|
||||||
|
raise NotFoundError(f"No such registration token: {token}")
|
||||||
|
|
||||||
|
return 200, token_info
|
||||||
|
|
||||||
|
async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
|
||||||
|
"""Update a registration token."""
|
||||||
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
new_attributes = {}
|
||||||
|
|
||||||
|
# Only add uses_allowed to new_attributes if it is present and valid
|
||||||
|
if "uses_allowed" in body:
|
||||||
|
uses_allowed = body["uses_allowed"]
|
||||||
|
if not (
|
||||||
|
uses_allowed is None
|
||||||
|
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
|
||||||
|
):
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"uses_allowed must be a non-negative integer or null",
|
||||||
|
Codes.INVALID_PARAM,
|
||||||
|
)
|
||||||
|
new_attributes["uses_allowed"] = uses_allowed
|
||||||
|
|
||||||
|
if "expiry_time" in body:
|
||||||
|
expiry_time = body["expiry_time"]
|
||||||
|
if not isinstance(expiry_time, (int, type(None))):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
|
||||||
|
raise SynapseError(
|
||||||
|
400, "expiry_time must not be in the past", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
new_attributes["expiry_time"] = expiry_time
|
||||||
|
|
||||||
|
if len(new_attributes) == 0:
|
||||||
|
# Nothing to update, get token info to return
|
||||||
|
token_info = await self.store.get_one_registration_token(token)
|
||||||
|
else:
|
||||||
|
token_info = await self.store.update_registration_token(
|
||||||
|
token, new_attributes
|
||||||
|
)
|
||||||
|
|
||||||
|
# If no result return a 404
|
||||||
|
if token_info is None:
|
||||||
|
raise NotFoundError(f"No such registration token: {token}")
|
||||||
|
|
||||||
|
return 200, token_info
|
||||||
|
|
||||||
|
async def on_DELETE(
|
||||||
|
self, request: SynapseRequest, token: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
"""Delete a registration token."""
|
||||||
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
|
if await self.store.delete_registration_token(token):
|
||||||
|
return 200, {}
|
||||||
|
|
||||||
|
raise NotFoundError(f"No such registration token: {token}")
|
@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet):
|
|||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
self.recaptcha_template = hs.config.recaptcha_template
|
self.recaptcha_template = hs.config.recaptcha_template
|
||||||
self.terms_template = hs.config.terms_template
|
self.terms_template = hs.config.terms_template
|
||||||
|
self.registration_token_template = hs.config.registration_token_template
|
||||||
self.success_template = hs.config.fallback_success_template
|
self.success_template = hs.config.fallback_success_template
|
||||||
|
|
||||||
async def on_GET(self, request, stagetype):
|
async def on_GET(self, request, stagetype):
|
||||||
@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet):
|
|||||||
# re-authenticate with their SSO provider.
|
# re-authenticate with their SSO provider.
|
||||||
html = await self.auth_handler.start_sso_ui_auth(request, session)
|
html = await self.auth_handler.start_sso_ui_auth(request, session)
|
||||||
|
|
||||||
|
elif stagetype == LoginType.REGISTRATION_TOKEN:
|
||||||
|
html = self.registration_token_template.render(
|
||||||
|
session=session,
|
||||||
|
myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise SynapseError(404, "Unknown auth stage type")
|
raise SynapseError(404, "Unknown auth stage type")
|
||||||
|
|
||||||
@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet):
|
|||||||
# The SSO fallback workflow should not post here,
|
# The SSO fallback workflow should not post here,
|
||||||
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
|
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
|
||||||
|
|
||||||
|
elif stagetype == LoginType.REGISTRATION_TOKEN:
|
||||||
|
token = parse_string(request, "token", required=True)
|
||||||
|
authdict = {"session": session, "token": token}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.auth_handler.add_oob_auth(
|
||||||
|
LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
|
||||||
|
)
|
||||||
|
except LoginError as e:
|
||||||
|
html = self.registration_token_template.render(
|
||||||
|
session=session,
|
||||||
|
myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
|
||||||
|
error=e.msg,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
html = self.success_template.render()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise SynapseError(404, "Unknown auth stage type")
|
raise SynapseError(404, "Unknown auth stage type")
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ from synapse.api.errors import (
|
|||||||
ThreepidValidationError,
|
ThreepidValidationError,
|
||||||
UnrecognizedRequestError,
|
UnrecognizedRequestError,
|
||||||
)
|
)
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.config.captcha import CaptchaConfig
|
from synapse.config.captcha import CaptchaConfig
|
||||||
from synapse.config.consent import ConsentConfig
|
from synapse.config.consent import ConsentConfig
|
||||||
@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
|||||||
return 200, {"available": True}
|
return 200, {"available": True}
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationTokenValidityRestServlet(RestServlet):
|
||||||
|
"""Check the validity of a registration token.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
|
||||||
|
|
||||||
|
200 OK
|
||||||
|
|
||||||
|
{
|
||||||
|
"valid": true
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = client_patterns(
|
||||||
|
f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
|
||||||
|
releases=(),
|
||||||
|
unstable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.ratelimiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
|
clock=hs.get_clock(),
|
||||||
|
rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
|
||||||
|
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_GET(self, request):
|
||||||
|
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
|
||||||
|
|
||||||
|
if not self.hs.config.enable_registration:
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
|
||||||
|
)
|
||||||
|
|
||||||
|
token = parse_string(request, "token", required=True)
|
||||||
|
valid = await self.store.registration_token_is_valid(token)
|
||||||
|
|
||||||
|
return 200, {"valid": valid}
|
||||||
|
|
||||||
|
|
||||||
class RegisterRestServlet(RestServlet):
|
class RegisterRestServlet(RestServlet):
|
||||||
PATTERNS = client_patterns("/register$")
|
PATTERNS = client_patterns("/register$")
|
||||||
|
|
||||||
@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if registered:
|
if registered:
|
||||||
|
# Check if a token was used to authenticate registration
|
||||||
|
registration_token = await self.auth_handler.get_session_data(
|
||||||
|
session_id,
|
||||||
|
UIAuthSessionDataConstants.REGISTRATION_TOKEN,
|
||||||
|
)
|
||||||
|
if registration_token:
|
||||||
|
# Increment the `completed` counter for the token
|
||||||
|
await self.store.use_registration_token(registration_token)
|
||||||
|
# Indicate that the token has been successfully used so that
|
||||||
|
# pending is not decremented again when expiring old UIA sessions.
|
||||||
|
await self.store.mark_ui_auth_stage_complete(
|
||||||
|
session_id,
|
||||||
|
LoginType.REGISTRATION_TOKEN,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
await self.registration_handler.post_registration_actions(
|
await self.registration_handler.post_registration_actions(
|
||||||
user_id=registered_user_id,
|
user_id=registered_user_id,
|
||||||
auth_result=auth_result,
|
auth_result=auth_result,
|
||||||
@ -868,6 +934,11 @@ def _calculate_registration_flows(
|
|||||||
for flow in flows:
|
for flow in flows:
|
||||||
flow.insert(0, LoginType.RECAPTCHA)
|
flow.insert(0, LoginType.RECAPTCHA)
|
||||||
|
|
||||||
|
# Prepend registration token to all flows if we're requiring a token
|
||||||
|
if config.registration_requires_token:
|
||||||
|
for flow in flows:
|
||||||
|
flow.insert(0, LoginType.REGISTRATION_TOKEN)
|
||||||
|
|
||||||
return flows
|
return flows
|
||||||
|
|
||||||
|
|
||||||
@ -876,4 +947,5 @@ def register_servlets(hs, http_server):
|
|||||||
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
|
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
|
||||||
UsernameAvailabilityRestServlet(hs).register(http_server)
|
UsernameAvailabilityRestServlet(hs).register(http_server)
|
||||||
RegistrationSubmitTokenServlet(hs).register(http_server)
|
RegistrationSubmitTokenServlet(hs).register(http_server)
|
||||||
|
RegistrationTokenValidityRestServlet(hs).register(http_server)
|
||||||
RegisterRestServlet(hs).register(http_server)
|
RegisterRestServlet(hs).register(http_server)
|
||||||
|
@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||||||
desc="update_access_token_last_validated",
|
desc="update_access_token_last_validated",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def registration_token_is_valid(self, token: str) -> bool:
|
||||||
|
"""Checks if a token can be used to authenticate a registration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The registration token to be checked
|
||||||
|
Returns:
|
||||||
|
True if the token is valid, False otherwise.
|
||||||
|
"""
|
||||||
|
res = await self.db_pool.simple_select_one(
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcols=["uses_allowed", "pending", "completed", "expiry_time"],
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the token exists
|
||||||
|
if res is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the token has expired
|
||||||
|
now = self._clock.time_msec()
|
||||||
|
if res["expiry_time"] and res["expiry_time"] < now:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the token has been used up
|
||||||
|
if (
|
||||||
|
res["uses_allowed"]
|
||||||
|
and res["pending"] + res["completed"] >= res["uses_allowed"]
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Otherwise, the token is valid
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def set_registration_token_pending(self, token: str) -> None:
|
||||||
|
"""Increment the pending registrations counter for a token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The registration token pending use
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _set_registration_token_pending_txn(txn):
|
||||||
|
pending = self.db_pool.simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcol="pending",
|
||||||
|
)
|
||||||
|
self.db_pool.simple_update_one_txn(
|
||||||
|
txn,
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
updatevalues={"pending": pending + 1},
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"set_registration_token_pending", _set_registration_token_pending_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
async def use_registration_token(self, token: str) -> None:
|
||||||
|
"""Complete a use of the given registration token.
|
||||||
|
|
||||||
|
The `pending` counter will be decremented, and the `completed`
|
||||||
|
counter will be incremented.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The registration token to be 'used'
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _use_registration_token_txn(txn):
|
||||||
|
# Normally, res is Optional[Dict[str, Any]].
|
||||||
|
# Override type because the return type is only optional if
|
||||||
|
# allow_none is True, and we don't want mypy throwing errors
|
||||||
|
# about None not being indexable.
|
||||||
|
res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
|
||||||
|
txn,
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcols=["pending", "completed"],
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
# Decrement pending and increment completed
|
||||||
|
self.db_pool.simple_update_one_txn(
|
||||||
|
txn,
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
updatevalues={
|
||||||
|
"completed": res["completed"] + 1,
|
||||||
|
"pending": res["pending"] - 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"use_registration_token", _use_registration_token_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_registration_tokens(
|
||||||
|
self, valid: Optional[bool] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""List all registration tokens. Used by the admin API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
valid: If True, only valid tokens are returned.
|
||||||
|
If False, only invalid tokens are returned.
|
||||||
|
Default is None: return all tokens regardless of validity.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of dicts, each containing details of a token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
|
||||||
|
if valid is None:
|
||||||
|
# Return all tokens regardless of validity
|
||||||
|
txn.execute("SELECT * FROM registration_tokens")
|
||||||
|
|
||||||
|
elif valid:
|
||||||
|
# Select valid tokens only
|
||||||
|
sql = (
|
||||||
|
"SELECT * FROM registration_tokens WHERE "
|
||||||
|
"(uses_allowed > pending + completed OR uses_allowed IS NULL) "
|
||||||
|
"AND (expiry_time > ? OR expiry_time IS NULL)"
|
||||||
|
)
|
||||||
|
txn.execute(sql, [now])
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Select invalid tokens only
|
||||||
|
sql = (
|
||||||
|
"SELECT * FROM registration_tokens WHERE "
|
||||||
|
"uses_allowed <= pending + completed OR expiry_time <= ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, [now])
|
||||||
|
|
||||||
|
return self.db_pool.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"select_registration_tokens",
|
||||||
|
select_registration_tokens_txn,
|
||||||
|
self._clock.time_msec(),
|
||||||
|
valid,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get info about the given registration token. Used by the admin API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to retrieve information about.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict, or None if token doesn't exist.
|
||||||
|
"""
|
||||||
|
return await self.db_pool.simple_select_one(
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_one_registration_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate_registration_token(
|
||||||
|
self, length: int, chars: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Generate a random registration token. Used by the admin API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
length: The length of the token to generate.
|
||||||
|
chars: A string of the characters allowed in the generated token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The generated token.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SynapseError if a unique registration token could still not be
|
||||||
|
generated after a few tries.
|
||||||
|
"""
|
||||||
|
# Make a few attempts at generating a unique token of the required
|
||||||
|
# length before failing.
|
||||||
|
for _i in range(3):
|
||||||
|
# Generate token
|
||||||
|
token = "".join(random.choices(chars, k=length))
|
||||||
|
|
||||||
|
# Check if the token already exists
|
||||||
|
existing_token = await self.db_pool.simple_select_one_onecol(
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcol="token",
|
||||||
|
allow_none=True,
|
||||||
|
desc="check_if_registration_token_exists",
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_token is None:
|
||||||
|
# The generated token doesn't exist yet, return it
|
||||||
|
return token
|
||||||
|
|
||||||
|
raise SynapseError(
|
||||||
|
500,
|
||||||
|
"Unable to generate a unique registration token. Try again with a greater length",
|
||||||
|
Codes.UNKNOWN,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_registration_token(
|
||||||
|
self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
|
||||||
|
) -> bool:
|
||||||
|
"""Create a new registration token. Used by the admin API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to create.
|
||||||
|
uses_allowed: The number of times the token can be used to complete
|
||||||
|
a registration before it becomes invalid. A value of None indicates
|
||||||
|
unlimited uses.
|
||||||
|
expiry_time: The latest time the token is valid. Given as the
|
||||||
|
number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
|
||||||
|
None indicates that the token does not expire.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the row was inserted or not.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _create_registration_token_txn(txn):
|
||||||
|
row = self.db_pool.simple_select_one_txn(
|
||||||
|
txn,
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcols=["token"],
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if row is not None:
|
||||||
|
# Token already exists
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.db_pool.simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
"registration_tokens",
|
||||||
|
values={
|
||||||
|
"token": token,
|
||||||
|
"uses_allowed": uses_allowed,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": expiry_time,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"create_registration_token", _create_registration_token_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_registration_token(
|
||||||
|
self, token: str, updatevalues: Dict[str, Optional[int]]
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Update a registration token. Used by the admin API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to update.
|
||||||
|
updatevalues: A dict with the fields to update. E.g.:
|
||||||
|
`{"uses_allowed": 3}` to update just uses_allowed, or
|
||||||
|
`{"uses_allowed": 3, "expiry_time": None}` to update both.
|
||||||
|
This is passed straight to simple_update_one.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict with all info about the token, or None if token doesn't exist.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _update_registration_token_txn(txn):
|
||||||
|
try:
|
||||||
|
self.db_pool.simple_update_one_txn(
|
||||||
|
txn,
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
updatevalues=updatevalues,
|
||||||
|
)
|
||||||
|
except StoreError:
|
||||||
|
# Update failed because token does not exist
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get all info about the token so it can be sent in the response
|
||||||
|
return self.db_pool.simple_select_one_txn(
|
||||||
|
txn,
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcols=[
|
||||||
|
"token",
|
||||||
|
"uses_allowed",
|
||||||
|
"pending",
|
||||||
|
"completed",
|
||||||
|
"expiry_time",
|
||||||
|
],
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"update_registration_token", _update_registration_token_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_registration_token(self, token: str) -> bool:
|
||||||
|
"""Delete a registration token. Used by the admin API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to delete.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the token was successfully deleted or not.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await self.db_pool.simple_delete_one(
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
desc="delete_registration_token",
|
||||||
|
)
|
||||||
|
except StoreError:
|
||||||
|
# Deletion failed because token does not exist
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def mark_access_token_as_used(self, token_id: int) -> None:
|
async def mark_access_token_as_used(self, token_id: int) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import LoggingTransaction
|
from synapse.storage.database import LoggingTransaction
|
||||||
@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||||||
keyvalues={},
|
keyvalues={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If a registration token was used, decrement the pending counter
|
||||||
|
# before deleting the session.
|
||||||
|
rows = self.db_pool.simple_select_many_txn(
|
||||||
|
txn,
|
||||||
|
table="ui_auth_sessions_credentials",
|
||||||
|
column="session_id",
|
||||||
|
iterable=session_ids,
|
||||||
|
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
|
||||||
|
retcols=["result"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the tokens used and how much pending needs to be decremented by.
|
||||||
|
token_counts: Dict[str, int] = {}
|
||||||
|
for r in rows:
|
||||||
|
# If registration was successfully completed, the result of the
|
||||||
|
# registration token stage for that session will be True.
|
||||||
|
# If a token was used to authenticate, but registration was
|
||||||
|
# never completed, the result will be the token used.
|
||||||
|
token = db_to_json(r["result"])
|
||||||
|
if isinstance(token, str):
|
||||||
|
token_counts[token] = token_counts.get(token, 0) + 1
|
||||||
|
|
||||||
|
# Update the `pending` counters.
|
||||||
|
if len(token_counts) > 0:
|
||||||
|
token_rows = self.db_pool.simple_select_many_txn(
|
||||||
|
txn,
|
||||||
|
table="registration_tokens",
|
||||||
|
column="token",
|
||||||
|
iterable=list(token_counts.keys()),
|
||||||
|
keyvalues={},
|
||||||
|
retcols=["token", "pending"],
|
||||||
|
)
|
||||||
|
for token_row in token_rows:
|
||||||
|
token = token_row["token"]
|
||||||
|
new_pending = token_row["pending"] - token_counts[token]
|
||||||
|
self.db_pool.simple_update_one_txn(
|
||||||
|
txn,
|
||||||
|
table="registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
updatevalues={"pending": new_pending},
|
||||||
|
)
|
||||||
|
|
||||||
# Delete the corresponding completed credentials.
|
# Delete the corresponding completed credentials.
|
||||||
self.db_pool.simple_delete_many_txn(
|
self.db_pool.simple_delete_many_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
/* Copyright 2021 Callum Brown
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS registration_tokens(
|
||||||
|
token TEXT NOT NULL, -- The token that can be used for authentication.
|
||||||
|
uses_allowed INT, -- The total number of times this token can be used. NULL if no limit.
|
||||||
|
pending INT NOT NULL, -- The number of in progress registrations using this token.
|
||||||
|
completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
|
||||||
|
expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
|
||||||
|
UNIQUE (token)
|
||||||
|
);
|
710
tests/rest/admin/test_registration_tokens.py
Normal file
710
tests/rest/admin/test_registration_tokens.py
Normal file
@ -0,0 +1,710 @@
|
|||||||
|
# Copyright 2021 Callum Brown
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
|
||||||
|
import synapse.rest.admin
|
||||||
|
from synapse.api.errors import Codes
|
||||||
|
from synapse.rest.client import login
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, hs):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||||
|
self.admin_user_tok = self.login("admin", "pass")
|
||||||
|
|
||||||
|
self.other_user = self.register_user("user", "pass")
|
||||||
|
self.other_user_tok = self.login("user", "pass")
|
||||||
|
|
||||||
|
self.url = "/_synapse/admin/v1/registration_tokens"
|
||||||
|
|
||||||
|
def _new_token(self, **kwargs):
|
||||||
|
"""Helper function to create a token."""
|
||||||
|
token = kwargs.get(
|
||||||
|
"token",
|
||||||
|
"".join(random.choices(string.ascii_letters, k=8)),
|
||||||
|
)
|
||||||
|
self.get_success(
|
||||||
|
self.store.db_pool.simple_insert(
|
||||||
|
"registration_tokens",
|
||||||
|
{
|
||||||
|
"token": token,
|
||||||
|
"uses_allowed": kwargs.get("uses_allowed", None),
|
||||||
|
"pending": kwargs.get("pending", 0),
|
||||||
|
"completed": kwargs.get("completed", 0),
|
||||||
|
"expiry_time": kwargs.get("expiry_time", None),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return token
|
||||||
|
|
||||||
|
# CREATION
|
||||||
|
|
||||||
|
def test_create_no_auth(self):
|
||||||
|
"""Try to create a token without authentication."""
|
||||||
|
channel = self.make_request("POST", self.url + "/new", {})
|
||||||
|
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_create_requester_not_admin(self):
|
||||||
|
"""Try to create a token while not an admin."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{},
|
||||||
|
access_token=self.other_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_create_using_defaults(self):
|
||||||
|
"""Create a token using all the defaults."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(len(channel.json_body["token"]), 16)
|
||||||
|
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||||
|
self.assertIsNone(channel.json_body["expiry_time"])
|
||||||
|
self.assertEqual(channel.json_body["pending"], 0)
|
||||||
|
self.assertEqual(channel.json_body["completed"], 0)
|
||||||
|
|
||||||
|
def test_create_specifying_fields(self):
|
||||||
|
"""Create a token specifying the value of all fields."""
|
||||||
|
data = {
|
||||||
|
"token": "abcd",
|
||||||
|
"uses_allowed": 1,
|
||||||
|
"expiry_time": self.clock.time_msec() + 1000000,
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
data,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["token"], "abcd")
|
||||||
|
self.assertEqual(channel.json_body["uses_allowed"], 1)
|
||||||
|
self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
|
||||||
|
self.assertEqual(channel.json_body["pending"], 0)
|
||||||
|
self.assertEqual(channel.json_body["completed"], 0)
|
||||||
|
|
||||||
|
def test_create_with_null_value(self):
|
||||||
|
"""Create a token specifying unlimited uses and no expiry."""
|
||||||
|
data = {
|
||||||
|
"uses_allowed": None,
|
||||||
|
"expiry_time": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
data,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(len(channel.json_body["token"]), 16)
|
||||||
|
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||||
|
self.assertIsNone(channel.json_body["expiry_time"])
|
||||||
|
self.assertEqual(channel.json_body["pending"], 0)
|
||||||
|
self.assertEqual(channel.json_body["completed"], 0)
|
||||||
|
|
||||||
|
def test_create_token_too_long(self):
|
||||||
|
"""Check token longer than 64 chars is invalid."""
|
||||||
|
data = {"token": "a" * 65}
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
data,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
def test_create_token_invalid_chars(self):
|
||||||
|
"""Check you can't create token with invalid characters."""
|
||||||
|
data = {
|
||||||
|
"token": "abc/def",
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
data,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
def test_create_token_already_exists(self):
|
||||||
|
"""Check you can't create token that already exists."""
|
||||||
|
data = {
|
||||||
|
"token": "abcd",
|
||||||
|
}
|
||||||
|
|
||||||
|
channel1 = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
data,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
|
||||||
|
|
||||||
|
channel2 = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
data,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
|
||||||
|
self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
def test_create_unable_to_generate_token(self):
|
||||||
|
"""Check right error is raised when server can't generate unique token."""
|
||||||
|
# Create all possible single character tokens
|
||||||
|
tokens = []
|
||||||
|
for c in string.ascii_letters + string.digits + "-_":
|
||||||
|
tokens.append(
|
||||||
|
{
|
||||||
|
"token": c,
|
||||||
|
"uses_allowed": None,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.get_success(
|
||||||
|
self.store.db_pool.simple_insert_many(
|
||||||
|
"registration_tokens",
|
||||||
|
tokens,
|
||||||
|
"create_all_registration_tokens",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check creating a single character token fails with a 500 status code
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{"length": 1},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
def test_create_uses_allowed(self):
|
||||||
|
"""Check you can only create a token with good values for uses_allowed."""
|
||||||
|
# Should work with 0 (token is invalid from the start)
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{"uses_allowed": 0},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["uses_allowed"], 0)
|
||||||
|
|
||||||
|
# Should fail with negative integer
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{"uses_allowed": -5},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
# Should fail with float
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{"uses_allowed": 1.5},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
def test_create_expiry_time(self):
|
||||||
|
"""Check you can't create a token with an invalid expiry_time."""
|
||||||
|
# Should fail with a time in the past
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{"expiry_time": self.clock.time_msec() - 10000},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
# Should fail with float
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{"expiry_time": self.clock.time_msec() + 1000000.5},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
def test_create_length(self):
|
||||||
|
"""Check you can only generate a token with a valid length."""
|
||||||
|
# Should work with 64
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{"length": 64},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(len(channel.json_body["token"]), 64)
|
||||||
|
|
||||||
|
# Should fail with 0
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{"length": 0},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
# Should fail with a negative integer
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{"length": -5},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
# Should fail with a float
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{"length": 8.5},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
# Should fail with 65
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url + "/new",
|
||||||
|
{"length": 65},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
# UPDATING
|
||||||
|
|
||||||
|
def test_update_no_auth(self):
|
||||||
|
"""Try to update a token without authentication."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_update_requester_not_admin(self):
|
||||||
|
"""Try to update a token while not an admin."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||||
|
{},
|
||||||
|
access_token=self.other_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_update_non_existent(self):
|
||||||
|
"""Try to update a token that doesn't exist."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/1234",
|
||||||
|
{"uses_allowed": 1},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||||
|
|
||||||
|
def test_update_uses_allowed(self):
|
||||||
|
"""Test updating just uses_allowed."""
|
||||||
|
# Create new token using default values
|
||||||
|
token = self._new_token()
|
||||||
|
|
||||||
|
# Should succeed with 1
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/" + token,
|
||||||
|
{"uses_allowed": 1},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["uses_allowed"], 1)
|
||||||
|
self.assertIsNone(channel.json_body["expiry_time"])
|
||||||
|
|
||||||
|
# Should succeed with 0 (makes token invalid)
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/" + token,
|
||||||
|
{"uses_allowed": 0},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["uses_allowed"], 0)
|
||||||
|
self.assertIsNone(channel.json_body["expiry_time"])
|
||||||
|
|
||||||
|
# Should succeed with null
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/" + token,
|
||||||
|
{"uses_allowed": None},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||||
|
self.assertIsNone(channel.json_body["expiry_time"])
|
||||||
|
|
||||||
|
# Should fail with a float
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/" + token,
|
||||||
|
{"uses_allowed": 1.5},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
# Should fail with a negative integer
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/" + token,
|
||||||
|
{"uses_allowed": -5},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
def test_update_expiry_time(self):
|
||||||
|
"""Test updating just expiry_time."""
|
||||||
|
# Create new token using default values
|
||||||
|
token = self._new_token()
|
||||||
|
new_expiry_time = self.clock.time_msec() + 1000000
|
||||||
|
|
||||||
|
# Should succeed with a time in the future
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/" + token,
|
||||||
|
{"expiry_time": new_expiry_time},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
|
||||||
|
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||||
|
|
||||||
|
# Should succeed with null
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/" + token,
|
||||||
|
{"expiry_time": None},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertIsNone(channel.json_body["expiry_time"])
|
||||||
|
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||||
|
|
||||||
|
# Should fail with a time in the past
|
||||||
|
past_time = self.clock.time_msec() - 10000
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/" + token,
|
||||||
|
{"expiry_time": past_time},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
# Should fail a float
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/" + token,
|
||||||
|
{"expiry_time": new_expiry_time + 0.5},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
def test_update_both(self):
|
||||||
|
"""Test updating both uses_allowed and expiry_time."""
|
||||||
|
# Create new token using default values
|
||||||
|
token = self._new_token()
|
||||||
|
new_expiry_time = self.clock.time_msec() + 1000000
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"uses_allowed": 1,
|
||||||
|
"expiry_time": new_expiry_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/" + token,
|
||||||
|
data,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["uses_allowed"], 1)
|
||||||
|
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
|
||||||
|
|
||||||
|
def test_update_invalid_type(self):
|
||||||
|
"""Test using invalid types doesn't work."""
|
||||||
|
# Create new token using default values
|
||||||
|
token = self._new_token()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"uses_allowed": False,
|
||||||
|
"expiry_time": "1626430124000",
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url + "/" + token,
|
||||||
|
data,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
# DELETING
|
||||||
|
|
||||||
|
def test_delete_no_auth(self):
|
||||||
|
"""Try to delete a token without authentication."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"DELETE",
|
||||||
|
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_delete_requester_not_admin(self):
|
||||||
|
"""Try to delete a token while not an admin."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"DELETE",
|
||||||
|
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||||
|
{},
|
||||||
|
access_token=self.other_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_delete_non_existent(self):
|
||||||
|
"""Try to delete a token that doesn't exist."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"DELETE",
|
||||||
|
self.url + "/1234",
|
||||||
|
{},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||||
|
|
||||||
|
def test_delete(self):
|
||||||
|
"""Test deleting a token."""
|
||||||
|
# Create new token using default values
|
||||||
|
token = self._new_token()
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"DELETE",
|
||||||
|
self.url + "/" + token,
|
||||||
|
{},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
# GETTING ONE
|
||||||
|
|
||||||
|
def test_get_no_auth(self):
|
||||||
|
"""Try to get a token without authentication."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_get_requester_not_admin(self):
|
||||||
|
"""Try to get a token while not an admin."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||||
|
{},
|
||||||
|
access_token=self.other_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_get_non_existent(self):
|
||||||
|
"""Try to get a token that doesn't exist."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.url + "/1234",
|
||||||
|
{},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||||
|
|
||||||
|
def test_get(self):
|
||||||
|
"""Test getting a token."""
|
||||||
|
# Create new token using default values
|
||||||
|
token = self._new_token()
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.url + "/" + token,
|
||||||
|
{},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(channel.json_body["token"], token)
|
||||||
|
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||||
|
self.assertIsNone(channel.json_body["expiry_time"])
|
||||||
|
self.assertEqual(channel.json_body["pending"], 0)
|
||||||
|
self.assertEqual(channel.json_body["completed"], 0)
|
||||||
|
|
||||||
|
# LISTING
|
||||||
|
|
||||||
|
def test_list_no_auth(self):
|
||||||
|
"""Try to list tokens without authentication."""
|
||||||
|
channel = self.make_request("GET", self.url, {})
|
||||||
|
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_list_requester_not_admin(self):
|
||||||
|
"""Try to list tokens while not an admin."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.url,
|
||||||
|
{},
|
||||||
|
access_token=self.other_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_list_all(self):
|
||||||
|
"""Test listing all tokens."""
|
||||||
|
# Create new token using default values
|
||||||
|
token = self._new_token()
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.url,
|
||||||
|
{},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
|
||||||
|
token_info = channel.json_body["registration_tokens"][0]
|
||||||
|
self.assertEqual(token_info["token"], token)
|
||||||
|
self.assertIsNone(token_info["uses_allowed"])
|
||||||
|
self.assertIsNone(token_info["expiry_time"])
|
||||||
|
self.assertEqual(token_info["pending"], 0)
|
||||||
|
self.assertEqual(token_info["completed"], 0)
|
||||||
|
|
||||||
|
def test_list_invalid_query_parameter(self):
|
||||||
|
"""Test with `valid` query parameter not `true` or `false`."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.url + "?valid=x",
|
||||||
|
{},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
def _test_list_query_parameter(self, valid: str):
|
||||||
|
"""Helper used to test both valid=true and valid=false."""
|
||||||
|
# Create 2 valid and 2 invalid tokens.
|
||||||
|
now = self.hs.get_clock().time_msec()
|
||||||
|
# Create always valid token
|
||||||
|
valid1 = self._new_token()
|
||||||
|
# Create token that hasn't been used up
|
||||||
|
valid2 = self._new_token(uses_allowed=1)
|
||||||
|
# Create token that has expired
|
||||||
|
invalid1 = self._new_token(expiry_time=now - 10000)
|
||||||
|
# Create token that has been used up but hasn't expired
|
||||||
|
invalid2 = self._new_token(
|
||||||
|
uses_allowed=2,
|
||||||
|
pending=1,
|
||||||
|
completed=1,
|
||||||
|
expiry_time=now + 1000000,
|
||||||
|
)
|
||||||
|
|
||||||
|
if valid == "true":
|
||||||
|
tokens = [valid1, valid2]
|
||||||
|
else:
|
||||||
|
tokens = [invalid1, invalid2]
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.url + "?valid=" + valid,
|
||||||
|
{},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
|
||||||
|
token_info_1 = channel.json_body["registration_tokens"][0]
|
||||||
|
token_info_2 = channel.json_body["registration_tokens"][1]
|
||||||
|
self.assertIn(token_info_1["token"], tokens)
|
||||||
|
self.assertIn(token_info_2["token"], tokens)
|
||||||
|
|
||||||
|
def test_list_valid(self):
|
||||||
|
"""Test listing just valid tokens."""
|
||||||
|
self._test_list_query_parameter(valid="true")
|
||||||
|
|
||||||
|
def test_list_invalid(self):
|
||||||
|
"""Test listing just invalid tokens."""
|
||||||
|
self._test_list_query_parameter(valid="false")
|
@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
|
|||||||
from synapse.api.errors import Codes
|
from synapse.api.errors import Codes
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.rest.client import account, account_validity, login, logout, register, sync
|
from synapse.rest.client import account, account_validity, login, logout, register, sync
|
||||||
|
from synapse.storage._base import db_to_json
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
|
@override_config({"registration_requires_token": True})
|
||||||
|
def test_POST_registration_requires_token(self):
|
||||||
|
username = "kermit"
|
||||||
|
device_id = "frogfone"
|
||||||
|
token = "abcd"
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
self.get_success(
|
||||||
|
store.db_pool.simple_insert(
|
||||||
|
"registration_tokens",
|
||||||
|
{
|
||||||
|
"token": token,
|
||||||
|
"uses_allowed": None,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
params = {
|
||||||
|
"username": username,
|
||||||
|
"password": "monkey",
|
||||||
|
"device_id": device_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Request without auth to get flows and session
|
||||||
|
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
|
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||||
|
flows = channel.json_body["flows"]
|
||||||
|
# Synapse adds a dummy stage to differentiate flows where otherwise one
|
||||||
|
# flow would be a subset of another flow.
|
||||||
|
self.assertCountEqual(
|
||||||
|
[[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
|
||||||
|
(f["stages"] for f in flows),
|
||||||
|
)
|
||||||
|
session = channel.json_body["session"]
|
||||||
|
|
||||||
|
# Do the registration token stage and check it has completed
|
||||||
|
params["auth"] = {
|
||||||
|
"type": LoginType.REGISTRATION_TOKEN,
|
||||||
|
"token": token,
|
||||||
|
"session": session,
|
||||||
|
}
|
||||||
|
request_data = json.dumps(params)
|
||||||
|
channel = self.make_request(b"POST", self.url, request_data)
|
||||||
|
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||||
|
completed = channel.json_body["completed"]
|
||||||
|
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
|
||||||
|
|
||||||
|
# Do the m.login.dummy stage and check registration was successful
|
||||||
|
params["auth"] = {
|
||||||
|
"type": LoginType.DUMMY,
|
||||||
|
"session": session,
|
||||||
|
}
|
||||||
|
request_data = json.dumps(params)
|
||||||
|
channel = self.make_request(b"POST", self.url, request_data)
|
||||||
|
det_data = {
|
||||||
|
"user_id": f"@{username}:{self.hs.hostname}",
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
"device_id": device_id,
|
||||||
|
}
|
||||||
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||||
|
|
||||||
|
# Check the `completed` counter has been incremented and pending is 0
|
||||||
|
res = self.get_success(
|
||||||
|
store.db_pool.simple_select_one(
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcols=["pending", "completed"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEquals(res["completed"], 1)
|
||||||
|
self.assertEquals(res["pending"], 0)
|
||||||
|
|
||||||
|
@override_config({"registration_requires_token": True})
|
||||||
|
def test_POST_registration_token_invalid(self):
|
||||||
|
params = {
|
||||||
|
"username": "kermit",
|
||||||
|
"password": "monkey",
|
||||||
|
}
|
||||||
|
# Request without auth to get session
|
||||||
|
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
|
session = channel.json_body["session"]
|
||||||
|
|
||||||
|
# Test with token param missing (invalid)
|
||||||
|
params["auth"] = {
|
||||||
|
"type": LoginType.REGISTRATION_TOKEN,
|
||||||
|
"session": session,
|
||||||
|
}
|
||||||
|
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
|
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
|
||||||
|
self.assertEquals(channel.json_body["completed"], [])
|
||||||
|
|
||||||
|
# Test with non-string (invalid)
|
||||||
|
params["auth"]["token"] = 1234
|
||||||
|
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
|
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||||
|
self.assertEquals(channel.json_body["completed"], [])
|
||||||
|
|
||||||
|
# Test with unknown token (invalid)
|
||||||
|
params["auth"]["token"] = "1234"
|
||||||
|
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
|
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||||
|
self.assertEquals(channel.json_body["completed"], [])
|
||||||
|
|
||||||
|
@override_config({"registration_requires_token": True})
|
||||||
|
def test_POST_registration_token_limit_uses(self):
|
||||||
|
token = "abcd"
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
# Create token that can be used once
|
||||||
|
self.get_success(
|
||||||
|
store.db_pool.simple_insert(
|
||||||
|
"registration_tokens",
|
||||||
|
{
|
||||||
|
"token": token,
|
||||||
|
"uses_allowed": 1,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
params1 = {"username": "bert", "password": "monkey"}
|
||||||
|
params2 = {"username": "ernie", "password": "monkey"}
|
||||||
|
# Do 2 requests without auth to get two session IDs
|
||||||
|
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
||||||
|
session1 = channel1.json_body["session"]
|
||||||
|
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||||
|
session2 = channel2.json_body["session"]
|
||||||
|
|
||||||
|
# Use token with session1 and check `pending` is 1
|
||||||
|
params1["auth"] = {
|
||||||
|
"type": LoginType.REGISTRATION_TOKEN,
|
||||||
|
"token": token,
|
||||||
|
"session": session1,
|
||||||
|
}
|
||||||
|
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||||
|
# Repeat request to make sure pending isn't increased again
|
||||||
|
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||||
|
pending = self.get_success(
|
||||||
|
store.db_pool.simple_select_one_onecol(
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcol="pending",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEquals(pending, 1)
|
||||||
|
|
||||||
|
# Check auth fails when using token with session2
|
||||||
|
params2["auth"] = {
|
||||||
|
"type": LoginType.REGISTRATION_TOKEN,
|
||||||
|
"token": token,
|
||||||
|
"session": session2,
|
||||||
|
}
|
||||||
|
channel = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||||
|
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||||
|
self.assertEquals(channel.json_body["completed"], [])
|
||||||
|
|
||||||
|
# Complete registration with session1
|
||||||
|
params1["auth"]["type"] = LoginType.DUMMY
|
||||||
|
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||||
|
# Check pending=0 and completed=1
|
||||||
|
res = self.get_success(
|
||||||
|
store.db_pool.simple_select_one(
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcols=["pending", "completed"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEquals(res["pending"], 0)
|
||||||
|
self.assertEquals(res["completed"], 1)
|
||||||
|
|
||||||
|
# Check auth still fails when using token with session2
|
||||||
|
channel = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||||
|
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||||
|
self.assertEquals(channel.json_body["completed"], [])
|
||||||
|
|
||||||
|
@override_config({"registration_requires_token": True})
|
||||||
|
def test_POST_registration_token_expiry(self):
|
||||||
|
token = "abcd"
|
||||||
|
now = self.hs.get_clock().time_msec()
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
# Create token that expired yesterday
|
||||||
|
self.get_success(
|
||||||
|
store.db_pool.simple_insert(
|
||||||
|
"registration_tokens",
|
||||||
|
{
|
||||||
|
"token": token,
|
||||||
|
"uses_allowed": None,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": now - 24 * 60 * 60 * 1000,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
params = {"username": "kermit", "password": "monkey"}
|
||||||
|
# Request without auth to get session
|
||||||
|
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
|
session = channel.json_body["session"]
|
||||||
|
|
||||||
|
# Check authentication fails with expired token
|
||||||
|
params["auth"] = {
|
||||||
|
"type": LoginType.REGISTRATION_TOKEN,
|
||||||
|
"token": token,
|
||||||
|
"session": session,
|
||||||
|
}
|
||||||
|
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
|
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||||
|
self.assertEquals(channel.json_body["completed"], [])
|
||||||
|
|
||||||
|
# Update token so it expires tomorrow
|
||||||
|
self.get_success(
|
||||||
|
store.db_pool.simple_update_one(
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check authentication succeeds
|
||||||
|
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
|
completed = channel.json_body["completed"]
|
||||||
|
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
|
||||||
|
|
||||||
|
@override_config({"registration_requires_token": True})
|
||||||
|
def test_POST_registration_token_session_expiry(self):
|
||||||
|
"""Test `pending` is decremented when an uncompleted session expires."""
|
||||||
|
token = "abcd"
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
self.get_success(
|
||||||
|
store.db_pool.simple_insert(
|
||||||
|
"registration_tokens",
|
||||||
|
{
|
||||||
|
"token": token,
|
||||||
|
"uses_allowed": None,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Do 2 requests without auth to get two session IDs
|
||||||
|
params1 = {"username": "bert", "password": "monkey"}
|
||||||
|
params2 = {"username": "ernie", "password": "monkey"}
|
||||||
|
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
||||||
|
session1 = channel1.json_body["session"]
|
||||||
|
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||||
|
session2 = channel2.json_body["session"]
|
||||||
|
|
||||||
|
# Use token with both sessions
|
||||||
|
params1["auth"] = {
|
||||||
|
"type": LoginType.REGISTRATION_TOKEN,
|
||||||
|
"token": token,
|
||||||
|
"session": session1,
|
||||||
|
}
|
||||||
|
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||||
|
|
||||||
|
params2["auth"] = {
|
||||||
|
"type": LoginType.REGISTRATION_TOKEN,
|
||||||
|
"token": token,
|
||||||
|
"session": session2,
|
||||||
|
}
|
||||||
|
self.make_request(b"POST", self.url, json.dumps(params2))
|
||||||
|
|
||||||
|
# Complete registration with session1
|
||||||
|
params1["auth"]["type"] = LoginType.DUMMY
|
||||||
|
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||||
|
|
||||||
|
# Check `result` of registration token stage for session1 is `True`
|
||||||
|
result1 = self.get_success(
|
||||||
|
store.db_pool.simple_select_one_onecol(
|
||||||
|
"ui_auth_sessions_credentials",
|
||||||
|
keyvalues={
|
||||||
|
"session_id": session1,
|
||||||
|
"stage_type": LoginType.REGISTRATION_TOKEN,
|
||||||
|
},
|
||||||
|
retcol="result",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(db_to_json(result1))
|
||||||
|
|
||||||
|
# Check `result` for session2 is the token used
|
||||||
|
result2 = self.get_success(
|
||||||
|
store.db_pool.simple_select_one_onecol(
|
||||||
|
"ui_auth_sessions_credentials",
|
||||||
|
keyvalues={
|
||||||
|
"session_id": session2,
|
||||||
|
"stage_type": LoginType.REGISTRATION_TOKEN,
|
||||||
|
},
|
||||||
|
retcol="result",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEquals(db_to_json(result2), token)
|
||||||
|
|
||||||
|
# Delete both sessions (mimics expiry)
|
||||||
|
self.get_success(
|
||||||
|
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check pending is now 0
|
||||||
|
pending = self.get_success(
|
||||||
|
store.db_pool.simple_select_one_onecol(
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcol="pending",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEquals(pending, 0)
|
||||||
|
|
||||||
|
@override_config({"registration_requires_token": True})
|
||||||
|
def test_POST_registration_token_session_expiry_deleted_token(self):
|
||||||
|
"""Test session expiry doesn't break when the token is deleted.
|
||||||
|
|
||||||
|
1. Start but don't complete UIA with a registration token
|
||||||
|
2. Delete the token from the database
|
||||||
|
3. Expire the session
|
||||||
|
"""
|
||||||
|
token = "abcd"
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
self.get_success(
|
||||||
|
store.db_pool.simple_insert(
|
||||||
|
"registration_tokens",
|
||||||
|
{
|
||||||
|
"token": token,
|
||||||
|
"uses_allowed": None,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Do request without auth to get a session ID
|
||||||
|
params = {"username": "kermit", "password": "monkey"}
|
||||||
|
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
|
session = channel.json_body["session"]
|
||||||
|
|
||||||
|
# Use token
|
||||||
|
params["auth"] = {
|
||||||
|
"type": LoginType.REGISTRATION_TOKEN,
|
||||||
|
"token": token,
|
||||||
|
"session": session,
|
||||||
|
}
|
||||||
|
self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
|
|
||||||
|
# Delete token
|
||||||
|
self.get_success(
|
||||||
|
store.db_pool.simple_delete_one(
|
||||||
|
"registration_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete session (mimics expiry)
|
||||||
|
self.get_success(
|
||||||
|
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
|
||||||
|
)
|
||||||
|
|
||||||
def test_advertised_flows(self):
|
def test_advertised_flows(self):
|
||||||
channel = self.make_request(b"POST", self.url, b"{}")
|
channel = self.make_request(b"POST", self.url, b"{}")
|
||||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||||
@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
|
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
|
||||||
self.assertLessEqual(res, now_ms + self.validity_period)
|
self.assertLessEqual(res, now_ms + self.validity_period)
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
|
servlets = [register.register_servlets]
|
||||||
|
url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
|
||||||
|
|
||||||
|
def default_config(self):
|
||||||
|
config = super().default_config()
|
||||||
|
config["registration_requires_token"] = True
|
||||||
|
return config
|
||||||
|
|
||||||
|
def test_GET_token_valid(self):
|
||||||
|
token = "abcd"
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
self.get_success(
|
||||||
|
store.db_pool.simple_insert(
|
||||||
|
"registration_tokens",
|
||||||
|
{
|
||||||
|
"token": token,
|
||||||
|
"uses_allowed": None,
|
||||||
|
"pending": 0,
|
||||||
|
"completed": 0,
|
||||||
|
"expiry_time": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
b"GET",
|
||||||
|
f"{self.url}?token={token}",
|
||||||
|
)
|
||||||
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
self.assertEquals(channel.json_body["valid"], True)
|
||||||
|
|
||||||
|
def test_GET_token_invalid(self):
|
||||||
|
token = "1234"
|
||||||
|
channel = self.make_request(
|
||||||
|
b"GET",
|
||||||
|
f"{self.url}?token={token}",
|
||||||
|
)
|
||||||
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
self.assertEquals(channel.json_body["valid"], False)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
|
||||||
|
)
|
||||||
|
def test_GET_ratelimiting(self):
|
||||||
|
token = "1234"
|
||||||
|
|
||||||
|
for i in range(0, 6):
|
||||||
|
channel = self.make_request(
|
||||||
|
b"GET",
|
||||||
|
f"{self.url}?token={token}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if i == 5:
|
||||||
|
self.assertEquals(channel.result["code"], b"429", channel.result)
|
||||||
|
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
||||||
|
else:
|
||||||
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
|
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
b"GET",
|
||||||
|
f"{self.url}?token={token}",
|
||||||
|
)
|
||||||
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
Loading…
Reference in New Issue
Block a user