Save login tokens in database (#13844)

* Save login tokens in database

Signed-off-by: Quentin Gliech <quenting@element.io>

* Add upgrade notes

* Track login token reuse in a Prometheus metric

Signed-off-by: Quentin Gliech <quenting@element.io>
This commit is contained in:
Quentin Gliech 2022-10-26 12:45:41 +02:00 committed by GitHub
parent d902181de9
commit 8756d5c87e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 338 additions and 228 deletions

View file

@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.api.errors import (
Codes,
NotFoundError,
StoreError,
SynapseError,
ThreepidValidationError,
)
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception):
because this external id is given to an other user."""
class LoginTokenExpired(Exception):
"""Exception if the login token sent expired"""
class LoginTokenReused(Exception):
"""Exception if the login token sent was already used"""
@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenLookupResult:
"""Result of looking up an access token.
@ -115,6 +129,20 @@ class RefreshTokenLookupResult:
If None, the session can be refreshed indefinitely."""
@attr.s(auto_attribs=True, frozen=True, slots=True)
class LoginTokenLookupResult:
"""Result of looking up a login token."""
user_id: str
"""The user this token belongs to."""
auth_provider_id: Optional[str]
"""The SSO Identity Provider that the user authenticated with, to get this token."""
auth_provider_session_id: Optional[str]
"""The session ID advertised by the SSO Identity Provider."""
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@ -1789,6 +1817,109 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn
)
async def add_login_token_to_user(
self,
user_id: str,
token: str,
expiry_ts: int,
auth_provider_id: Optional[str],
auth_provider_session_id: Optional[str],
) -> None:
"""Adds a short-term login token for the given user.
Args:
user_id: The user ID.
token: The new login token to add.
expiry_ts (milliseconds since the epoch): Time after which the login token
cannot be used.
auth_provider_id: The SSO Identity Provider that the user authenticated with
to get this token, if any
auth_provider_session_id: The session ID advertised by the SSO Identity
Provider, if any.
"""
await self.db_pool.simple_insert(
"login_tokens",
{
"token": token,
"user_id": user_id,
"expiry_ts": expiry_ts,
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
desc="add_login_token_to_user",
)
def _consume_login_token(
self,
txn: LoggingTransaction,
token: str,
ts: int,
) -> LoginTokenLookupResult:
values = self.db_pool.simple_select_one_txn(
txn,
"login_tokens",
keyvalues={"token": token},
retcols=(
"user_id",
"expiry_ts",
"used_ts",
"auth_provider_id",
"auth_provider_session_id",
),
allow_none=True,
)
if values is None:
raise NotFoundError()
self.db_pool.simple_update_one_txn(
txn,
"login_tokens",
keyvalues={"token": token},
updatevalues={"used_ts": ts},
)
user_id = values["user_id"]
expiry_ts = values["expiry_ts"]
used_ts = values["used_ts"]
auth_provider_id = values["auth_provider_id"]
auth_provider_session_id = values["auth_provider_session_id"]
# Token was already used
if used_ts is not None:
raise LoginTokenReused()
# Token expired
if ts > int(expiry_ts):
raise LoginTokenExpired()
return LoginTokenLookupResult(
user_id=user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
async def consume_login_token(self, token: str) -> LoginTokenLookupResult:
"""Lookup a login token and consume it.
Args:
token: The login token.
Returns:
The data stored with that token, including the `user_id`. Returns `None` if
the token does not exist or if it expired.
Raises:
NotFound if the login token was not found in database
LoginTokenExpired if the login token expired
LoginTokenReused if the login token was already used
"""
return await self.db_pool.runInteraction(
"consume_login_token",
self._consume_login_token,
token,
self._clock.time_msec(),
)
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
@ -2019,6 +2150,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
# Create a background job for removing expired login tokens
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS
)
async def add_access_token_to_user(
self,
user_id: str,
@ -2617,6 +2754,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
approved,
)
@wrap_as_background_process("delete_expired_login_tokens")
async def _delete_expired_login_tokens(self) -> None:
"""Remove login tokens with expiry dates that have passed."""
def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None:
sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?"
txn.execute(sql, (ts,))
# We keep the expired tokens for an extra 5 minutes so we can measure how many
# times a token is being used after its expiry
now = self._clock.time_msec()
await self.db_pool.runInteraction(
"delete_expired_login_tokens",
_delete_expired_login_tokens_txn,
now - (5 * 60 * 1000),
)
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""