mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-02 18:14:49 -04:00
MSC2918 Refresh tokens implementation (#9450)
This implements refresh tokens, as defined by MSC2918 This MSC has been implemented client side in Hydrogen Web: vector-im/hydrogen-web#235 The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one. Signed-off-by: Quentin Gliech <quentingliech@gmail.com>
This commit is contained in:
parent
763dba77ef
commit
bd4919fb72
15 changed files with 892 additions and 61 deletions
|
@ -53,6 +53,9 @@ class TokenLookupResult:
|
|||
valid_until_ms: The timestamp the token expires, if any.
|
||||
token_owner: The "owner" of the token. This is either the same as the
|
||||
user, or a server admin who is logged in as the user.
|
||||
token_used: True if this token was used at least once in a request.
|
||||
This field can be out of date since `get_user_by_access_token` is
|
||||
cached.
|
||||
"""
|
||||
|
||||
user_id = attr.ib(type=str)
|
||||
|
@ -62,6 +65,7 @@ class TokenLookupResult:
|
|||
device_id = attr.ib(type=Optional[str], default=None)
|
||||
valid_until_ms = attr.ib(type=Optional[int], default=None)
|
||||
token_owner = attr.ib(type=str)
|
||||
token_used = attr.ib(type=bool, default=False)
|
||||
|
||||
# Make the token owner default to the user ID, which is the common case.
|
||||
@token_owner.default
|
||||
|
@ -69,6 +73,29 @@ class TokenLookupResult:
|
|||
return self.user_id
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True)
|
||||
class RefreshTokenLookupResult:
|
||||
"""Result of looking up a refresh token."""
|
||||
|
||||
user_id = attr.ib(type=str)
|
||||
"""The user this token belongs to."""
|
||||
|
||||
device_id = attr.ib(type=str)
|
||||
"""The device associated with this refresh token."""
|
||||
|
||||
token_id = attr.ib(type=int)
|
||||
"""The ID of this refresh token."""
|
||||
|
||||
next_token_id = attr.ib(type=Optional[int])
|
||||
"""The ID of the refresh token which replaced this one."""
|
||||
|
||||
has_next_refresh_token_been_refreshed = attr.ib(type=bool)
|
||||
"""True if the next refresh token was used for another refresh."""
|
||||
|
||||
has_next_access_token_been_used = attr.ib(type=bool)
|
||||
"""True if the next access token was already used at least once."""
|
||||
|
||||
|
||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -441,7 +468,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
access_tokens.id as token_id,
|
||||
access_tokens.device_id,
|
||||
access_tokens.valid_until_ms,
|
||||
access_tokens.user_id as token_owner
|
||||
access_tokens.user_id as token_owner,
|
||||
access_tokens.used as token_used
|
||||
FROM users
|
||||
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
|
||||
WHERE token = ?
|
||||
|
@ -449,8 +477,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
txn.execute(sql, (token,))
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
if rows:
|
||||
return TokenLookupResult(**rows[0])
|
||||
row = rows[0]
|
||||
|
||||
# This field is nullable, ensure it comes out as a boolean
|
||||
if row["token_used"] is None:
|
||||
row["token_used"] = False
|
||||
|
||||
return TokenLookupResult(**row)
|
||||
|
||||
return None
|
||||
|
||||
|
@ -1072,6 +1107,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
desc="update_access_token_last_validated",
|
||||
)
|
||||
|
||||
@cached()
|
||||
async def mark_access_token_as_used(self, token_id: int) -> None:
|
||||
"""
|
||||
Mark the access token as used, which invalidates the refresh token used
|
||||
to obtain it.
|
||||
|
||||
Because get_user_by_access_token is cached, this function might be
|
||||
called multiple times for the same token, effectively doing unnecessary
|
||||
SQL updates. Because updating the `used` field only goes one way (from
|
||||
False to True) it is safe to cache this function as well to avoid this
|
||||
issue.
|
||||
|
||||
Args:
|
||||
token_id: The ID of the access token to update.
|
||||
Raises:
|
||||
StoreError if there was a problem updating this.
|
||||
"""
|
||||
await self.db_pool.simple_update_one(
|
||||
"access_tokens",
|
||||
{"id": token_id},
|
||||
{"used": True},
|
||||
desc="mark_access_token_as_used",
|
||||
)
|
||||
|
||||
async def lookup_refresh_token(
|
||||
self, token: str
|
||||
) -> Optional[RefreshTokenLookupResult]:
|
||||
"""Lookup a refresh token with hints about its validity."""
|
||||
|
||||
def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT
|
||||
rt.id token_id,
|
||||
rt.user_id,
|
||||
rt.device_id,
|
||||
rt.next_token_id,
|
||||
(nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
|
||||
at.used has_next_access_token_been_used
|
||||
FROM refresh_tokens rt
|
||||
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
|
||||
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
|
||||
WHERE rt.token = ?
|
||||
""",
|
||||
(token,),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return RefreshTokenLookupResult(
|
||||
token_id=row[0],
|
||||
user_id=row[1],
|
||||
device_id=row[2],
|
||||
next_token_id=row[3],
|
||||
has_next_refresh_token_been_refreshed=row[4],
|
||||
# This column is nullable, ensure it's a boolean
|
||||
has_next_access_token_been_used=(row[5] or False),
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"lookup_refresh_token", _lookup_refresh_token_txn
|
||||
)
|
||||
|
||||
async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None:
|
||||
"""
|
||||
Set the successor of a refresh token, removing the existing successor
|
||||
if any.
|
||||
|
||||
Args:
|
||||
token_id: ID of the refresh token to update.
|
||||
next_token_id: ID of its successor.
|
||||
"""
|
||||
|
||||
def _replace_refresh_token_txn(txn) -> None:
|
||||
# First check if there was an existing refresh token
|
||||
old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
"refresh_tokens",
|
||||
{"id": token_id},
|
||||
"next_token_id",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"refresh_tokens",
|
||||
{"id": token_id},
|
||||
{"next_token_id": next_token_id},
|
||||
)
|
||||
|
||||
# Delete the old "next" token if it exists. This should cascade and
|
||||
# delete the associated access_token
|
||||
if old_next_token_id is not None:
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn,
|
||||
"refresh_tokens",
|
||||
{"id": old_next_token_id},
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"replace_refresh_token", _replace_refresh_token_txn
|
||||
)
|
||||
|
||||
|
||||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||
def __init__(
|
||||
|
@ -1263,6 +1403,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|
||||
|
||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
||||
|
||||
async def add_access_token_to_user(
|
||||
self,
|
||||
|
@ -1271,14 +1412,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
device_id: Optional[str],
|
||||
valid_until_ms: Optional[int],
|
||||
puppets_user_id: Optional[str] = None,
|
||||
refresh_token_id: Optional[int] = None,
|
||||
) -> int:
|
||||
"""Adds an access token for the given user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
token: The new access token to add.
|
||||
device_id: ID of the device to associate with the access token
|
||||
device_id: ID of the device to associate with the access token.
|
||||
valid_until_ms: when the token is valid until. None for no expiry.
|
||||
puppets_user_id
|
||||
refresh_token_id: ID of the refresh token generated alongside this
|
||||
access token.
|
||||
Raises:
|
||||
StoreError if there was a problem adding this.
|
||||
Returns:
|
||||
|
@ -1297,12 +1442,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
"valid_until_ms": valid_until_ms,
|
||||
"puppets_user_id": puppets_user_id,
|
||||
"last_validated": now,
|
||||
"refresh_token_id": refresh_token_id,
|
||||
"used": False,
|
||||
},
|
||||
desc="add_access_token_to_user",
|
||||
)
|
||||
|
||||
return next_id
|
||||
|
||||
async def add_refresh_token_to_user(
|
||||
self,
|
||||
user_id: str,
|
||||
token: str,
|
||||
device_id: Optional[str],
|
||||
) -> int:
|
||||
"""Adds a refresh token for the given user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
token: The new access token to add.
|
||||
device_id: ID of the device to associate with the refresh token.
|
||||
Raises:
|
||||
StoreError if there was a problem adding this.
|
||||
Returns:
|
||||
The token ID
|
||||
"""
|
||||
next_id = self._refresh_tokens_id_gen.get_next()
|
||||
|
||||
await self.db_pool.simple_insert(
|
||||
"refresh_tokens",
|
||||
{
|
||||
"id": next_id,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"token": token,
|
||||
"next_token_id": None,
|
||||
},
|
||||
desc="add_refresh_token_to_user",
|
||||
)
|
||||
|
||||
return next_id
|
||||
|
||||
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
|
||||
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn, "access_tokens", {"token": token}, "device_id"
|
||||
|
@ -1545,7 +1725,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
device_id: Optional[str] = None,
|
||||
) -> List[Tuple[str, int, Optional[str]]]:
|
||||
"""
|
||||
Invalidate access tokens belonging to a user
|
||||
Invalidate access and refresh tokens belonging to a user
|
||||
|
||||
Args:
|
||||
user_id: ID of user the tokens belong to
|
||||
|
@ -1565,7 +1745,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
items = keyvalues.items()
|
||||
where_clause = " AND ".join(k + " = ?" for k, _ in items)
|
||||
values = [v for _, v in items] # type: List[Union[str, int]]
|
||||
# Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
|
||||
# is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
|
||||
# clause and values before we handle that. This seems to be only used in the "set password" handler.
|
||||
refresh_where_clause = where_clause
|
||||
refresh_values = values.copy()
|
||||
if except_token_id:
|
||||
# TODO: support that for refresh tokens
|
||||
where_clause += " AND id != ?"
|
||||
values.append(except_token_id)
|
||||
|
||||
|
@ -1583,6 +1769,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
|
||||
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
|
||||
refresh_values,
|
||||
)
|
||||
|
||||
return tokens_and_devices
|
||||
|
||||
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
|
||||
|
@ -1599,6 +1790,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
|
||||
await self.db_pool.runInteraction("delete_access_token", f)
|
||||
|
||||
async def delete_refresh_token(self, refresh_token: str) -> None:
|
||||
def f(txn):
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("delete_refresh_token", f)
|
||||
|
||||
async def add_user_pending_deactivation(self, user_id: str) -> None:
|
||||
"""
|
||||
Adds a user to the table of users who need to be parted from all the rooms they're
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue