Implement MSC3231: Token authenticated registration (#10142)

Signed-off-by: Callum Brown <callum@calcuode.com>

This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231).
This commit is contained in:
Callum Brown 2021-08-21 22:14:43 +01:00 committed by GitHub
parent ecd823d766
commit 947dbbdfd1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 2389 additions and 1 deletions

View file

@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="update_access_token_last_validated",
)
async def registration_token_is_valid(self, token: str) -> bool:
"""Checks if a token can be used to authenticate a registration.
Args:
token: The registration token to be checked
Returns:
True if the token is valid, False otherwise.
"""
res = await self.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True,
)
# Check if the token exists
if res is None:
return False
# Check if the token has expired
now = self._clock.time_msec()
if res["expiry_time"] and res["expiry_time"] < now:
return False
# Check if the token has been used up
if (
res["uses_allowed"]
and res["pending"] + res["completed"] >= res["uses_allowed"]
):
return False
# Otherwise, the token is valid
return True
async def set_registration_token_pending(self, token: str) -> None:
"""Increment the pending registrations counter for a token.
Args:
token: The registration token pending use
"""
def _set_registration_token_pending_txn(txn):
pending = self.db_pool.simple_select_one_onecol_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcol="pending",
)
self.db_pool.simple_update_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
updatevalues={"pending": pending + 1},
)
return await self.db_pool.runInteraction(
"set_registration_token_pending", _set_registration_token_pending_txn
)
async def use_registration_token(self, token: str) -> None:
"""Complete a use of the given registration token.
The `pending` counter will be decremented, and the `completed`
counter will be incremented.
Args:
token: The registration token to be 'used'
"""
def _use_registration_token_txn(txn):
# Normally, res is Optional[Dict[str, Any]].
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
) # type: ignore
# Decrement pending and increment completed
self.db_pool.simple_update_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
updatevalues={
"completed": res["completed"] + 1,
"pending": res["pending"] - 1,
},
)
return await self.db_pool.runInteraction(
"use_registration_token", _use_registration_token_txn
)
async def get_registration_tokens(
self, valid: Optional[bool] = None
) -> List[Dict[str, Any]]:
"""List all registration tokens. Used by the admin API.
Args:
valid: If True, only valid tokens are returned.
If False, only invalid tokens are returned.
Default is None: return all tokens regardless of validity.
Returns:
A list of dicts, each containing details of a token.
"""
def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
if valid is None:
# Return all tokens regardless of validity
txn.execute("SELECT * FROM registration_tokens")
elif valid:
# Select valid tokens only
sql = (
"SELECT * FROM registration_tokens WHERE "
"(uses_allowed > pending + completed OR uses_allowed IS NULL) "
"AND (expiry_time > ? OR expiry_time IS NULL)"
)
txn.execute(sql, [now])
else:
# Select invalid tokens only
sql = (
"SELECT * FROM registration_tokens WHERE "
"uses_allowed <= pending + completed OR expiry_time <= ?"
)
txn.execute(sql, [now])
return self.db_pool.cursor_to_dict(txn)
return await self.db_pool.runInteraction(
"select_registration_tokens",
select_registration_tokens_txn,
self._clock.time_msec(),
valid,
)
async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Get info about the given registration token. Used by the admin API.
Args:
token: The token to retrieve information about.
Returns:
A dict, or None if token doesn't exist.
"""
return await self.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True,
desc="get_one_registration_token",
)
async def generate_registration_token(
self, length: int, chars: str
) -> Optional[str]:
"""Generate a random registration token. Used by the admin API.
Args:
length: The length of the token to generate.
chars: A string of the characters allowed in the generated token.
Returns:
The generated token.
Raises:
SynapseError if a unique registration token could still not be
generated after a few tries.
"""
# Make a few attempts at generating a unique token of the required
# length before failing.
for _i in range(3):
# Generate token
token = "".join(random.choices(chars, k=length))
# Check if the token already exists
existing_token = await self.db_pool.simple_select_one_onecol(
"registration_tokens",
keyvalues={"token": token},
retcol="token",
allow_none=True,
desc="check_if_registration_token_exists",
)
if existing_token is None:
# The generated token doesn't exist yet, return it
return token
raise SynapseError(
500,
"Unable to generate a unique registration token. Try again with a greater length",
Codes.UNKNOWN,
)
async def create_registration_token(
self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
) -> bool:
"""Create a new registration token. Used by the admin API.
Args:
token: The token to create.
uses_allowed: The number of times the token can be used to complete
a registration before it becomes invalid. A value of None indicates
unlimited uses.
expiry_time: The latest time the token is valid. Given as the
number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
None indicates that the token does not expire.
Returns:
Whether the row was inserted or not.
"""
def _create_registration_token_txn(txn):
row = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcols=["token"],
allow_none=True,
)
if row is not None:
# Token already exists
return False
self.db_pool.simple_insert_txn(
txn,
"registration_tokens",
values={
"token": token,
"uses_allowed": uses_allowed,
"pending": 0,
"completed": 0,
"expiry_time": expiry_time,
},
)
return True
return await self.db_pool.runInteraction(
"create_registration_token", _create_registration_token_txn
)
async def update_registration_token(
self, token: str, updatevalues: Dict[str, Optional[int]]
) -> Optional[Dict[str, Any]]:
"""Update a registration token. Used by the admin API.
Args:
token: The token to update.
updatevalues: A dict with the fields to update. E.g.:
`{"uses_allowed": 3}` to update just uses_allowed, or
`{"uses_allowed": 3, "expiry_time": None}` to update both.
This is passed straight to simple_update_one.
Returns:
A dict with all info about the token, or None if token doesn't exist.
"""
def _update_registration_token_txn(txn):
try:
self.db_pool.simple_update_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
updatevalues=updatevalues,
)
except StoreError:
# Update failed because token does not exist
return None
# Get all info about the token so it can be sent in the response
return self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcols=[
"token",
"uses_allowed",
"pending",
"completed",
"expiry_time",
],
allow_none=True,
)
return await self.db_pool.runInteraction(
"update_registration_token", _update_registration_token_txn
)
async def delete_registration_token(self, token: str) -> bool:
"""Delete a registration token. Used by the admin API.
Args:
token: The token to delete.
Returns:
Whether the token was successfully deleted or not.
"""
try:
await self.db_pool.simple_delete_one(
"registration_tokens",
keyvalues={"token": token},
desc="delete_registration_token",
)
except StoreError:
# Deletion failed because token does not exist
return False
return True
@cached()
async def mark_access_token_as_used(self, token_id: int) -> None:
"""