mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-07 16:35:05 -04:00
Implement MSC3231: Token authenticated registration (#10142)
Signed-off-by: Callum Brown <callum@calcuode.com> This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231).
This commit is contained in:
parent
ecd823d766
commit
947dbbdfd1
21 changed files with 2389 additions and 1 deletions
|
@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
desc="update_access_token_last_validated",
|
||||
)
|
||||
|
||||
async def registration_token_is_valid(self, token: str) -> bool:
|
||||
"""Checks if a token can be used to authenticate a registration.
|
||||
|
||||
Args:
|
||||
token: The registration token to be checked
|
||||
Returns:
|
||||
True if the token is valid, False otherwise.
|
||||
"""
|
||||
res = await self.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["uses_allowed", "pending", "completed", "expiry_time"],
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
# Check if the token exists
|
||||
if res is None:
|
||||
return False
|
||||
|
||||
# Check if the token has expired
|
||||
now = self._clock.time_msec()
|
||||
if res["expiry_time"] and res["expiry_time"] < now:
|
||||
return False
|
||||
|
||||
# Check if the token has been used up
|
||||
if (
|
||||
res["uses_allowed"]
|
||||
and res["pending"] + res["completed"] >= res["uses_allowed"]
|
||||
):
|
||||
return False
|
||||
|
||||
# Otherwise, the token is valid
|
||||
return True
|
||||
|
||||
async def set_registration_token_pending(self, token: str) -> None:
|
||||
"""Increment the pending registrations counter for a token.
|
||||
|
||||
Args:
|
||||
token: The registration token pending use
|
||||
"""
|
||||
|
||||
def _set_registration_token_pending_txn(txn):
|
||||
pending = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcol="pending",
|
||||
)
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={"pending": pending + 1},
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"set_registration_token_pending", _set_registration_token_pending_txn
|
||||
)
|
||||
|
||||
async def use_registration_token(self, token: str) -> None:
|
||||
"""Complete a use of the given registration token.
|
||||
|
||||
The `pending` counter will be decremented, and the `completed`
|
||||
counter will be incremented.
|
||||
|
||||
Args:
|
||||
token: The registration token to be 'used'
|
||||
"""
|
||||
|
||||
def _use_registration_token_txn(txn):
|
||||
# Normally, res is Optional[Dict[str, Any]].
|
||||
# Override type because the return type is only optional if
|
||||
# allow_none is True, and we don't want mypy throwing errors
|
||||
# about None not being indexable.
|
||||
res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["pending", "completed"],
|
||||
) # type: ignore
|
||||
|
||||
# Decrement pending and increment completed
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={
|
||||
"completed": res["completed"] + 1,
|
||||
"pending": res["pending"] - 1,
|
||||
},
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"use_registration_token", _use_registration_token_txn
|
||||
)
|
||||
|
||||
async def get_registration_tokens(
|
||||
self, valid: Optional[bool] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List all registration tokens. Used by the admin API.
|
||||
|
||||
Args:
|
||||
valid: If True, only valid tokens are returned.
|
||||
If False, only invalid tokens are returned.
|
||||
Default is None: return all tokens regardless of validity.
|
||||
|
||||
Returns:
|
||||
A list of dicts, each containing details of a token.
|
||||
"""
|
||||
|
||||
def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
|
||||
if valid is None:
|
||||
# Return all tokens regardless of validity
|
||||
txn.execute("SELECT * FROM registration_tokens")
|
||||
|
||||
elif valid:
|
||||
# Select valid tokens only
|
||||
sql = (
|
||||
"SELECT * FROM registration_tokens WHERE "
|
||||
"(uses_allowed > pending + completed OR uses_allowed IS NULL) "
|
||||
"AND (expiry_time > ? OR expiry_time IS NULL)"
|
||||
)
|
||||
txn.execute(sql, [now])
|
||||
|
||||
else:
|
||||
# Select invalid tokens only
|
||||
sql = (
|
||||
"SELECT * FROM registration_tokens WHERE "
|
||||
"uses_allowed <= pending + completed OR expiry_time <= ?"
|
||||
)
|
||||
txn.execute(sql, [now])
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"select_registration_tokens",
|
||||
select_registration_tokens_txn,
|
||||
self._clock.time_msec(),
|
||||
valid,
|
||||
)
|
||||
|
||||
async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get info about the given registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to retrieve information about.
|
||||
|
||||
Returns:
|
||||
A dict, or None if token doesn't exist.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
|
||||
allow_none=True,
|
||||
desc="get_one_registration_token",
|
||||
)
|
||||
|
||||
async def generate_registration_token(
|
||||
self, length: int, chars: str
|
||||
) -> Optional[str]:
|
||||
"""Generate a random registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
length: The length of the token to generate.
|
||||
chars: A string of the characters allowed in the generated token.
|
||||
|
||||
Returns:
|
||||
The generated token.
|
||||
|
||||
Raises:
|
||||
SynapseError if a unique registration token could still not be
|
||||
generated after a few tries.
|
||||
"""
|
||||
# Make a few attempts at generating a unique token of the required
|
||||
# length before failing.
|
||||
for _i in range(3):
|
||||
# Generate token
|
||||
token = "".join(random.choices(chars, k=length))
|
||||
|
||||
# Check if the token already exists
|
||||
existing_token = await self.db_pool.simple_select_one_onecol(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcol="token",
|
||||
allow_none=True,
|
||||
desc="check_if_registration_token_exists",
|
||||
)
|
||||
|
||||
if existing_token is None:
|
||||
# The generated token doesn't exist yet, return it
|
||||
return token
|
||||
|
||||
raise SynapseError(
|
||||
500,
|
||||
"Unable to generate a unique registration token. Try again with a greater length",
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
async def create_registration_token(
|
||||
self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
|
||||
) -> bool:
|
||||
"""Create a new registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to create.
|
||||
uses_allowed: The number of times the token can be used to complete
|
||||
a registration before it becomes invalid. A value of None indicates
|
||||
unlimited uses.
|
||||
expiry_time: The latest time the token is valid. Given as the
|
||||
number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
|
||||
None indicates that the token does not expire.
|
||||
|
||||
Returns:
|
||||
Whether the row was inserted or not.
|
||||
"""
|
||||
|
||||
def _create_registration_token_txn(txn):
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["token"],
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if row is not None:
|
||||
# Token already exists
|
||||
return False
|
||||
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
values={
|
||||
"token": token,
|
||||
"uses_allowed": uses_allowed,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": expiry_time,
|
||||
},
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"create_registration_token", _create_registration_token_txn
|
||||
)
|
||||
|
||||
async def update_registration_token(
|
||||
self, token: str, updatevalues: Dict[str, Optional[int]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Update a registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to update.
|
||||
updatevalues: A dict with the fields to update. E.g.:
|
||||
`{"uses_allowed": 3}` to update just uses_allowed, or
|
||||
`{"uses_allowed": 3, "expiry_time": None}` to update both.
|
||||
This is passed straight to simple_update_one.
|
||||
|
||||
Returns:
|
||||
A dict with all info about the token, or None if token doesn't exist.
|
||||
"""
|
||||
|
||||
def _update_registration_token_txn(txn):
|
||||
try:
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues=updatevalues,
|
||||
)
|
||||
except StoreError:
|
||||
# Update failed because token does not exist
|
||||
return None
|
||||
|
||||
# Get all info about the token so it can be sent in the response
|
||||
return self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=[
|
||||
"token",
|
||||
"uses_allowed",
|
||||
"pending",
|
||||
"completed",
|
||||
"expiry_time",
|
||||
],
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"update_registration_token", _update_registration_token_txn
|
||||
)
|
||||
|
||||
async def delete_registration_token(self, token: str) -> bool:
|
||||
"""Delete a registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to delete.
|
||||
|
||||
Returns:
|
||||
Whether the token was successfully deleted or not.
|
||||
"""
|
||||
try:
|
||||
await self.db_pool.simple_delete_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
desc="delete_registration_token",
|
||||
)
|
||||
except StoreError:
|
||||
# Deletion failed because token does not exist
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@cached()
|
||||
async def mark_access_token_as_used(self, token_id: int) -> None:
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue