mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-12-15 08:41:24 -05:00
MSC2918 Refresh tokens implementation (#9450)
This implements refresh tokens, as defined by MSC2918 This MSC has been implemented client side in Hydrogen Web: vector-im/hydrogen-web#235 The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one. Signed-off-by: Quentin Gliech <quentingliech@gmail.com>
This commit is contained in:
parent
763dba77ef
commit
bd4919fb72
15 changed files with 892 additions and 61 deletions
|
|
@ -30,6 +30,7 @@ from typing import (
|
|||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
|
@ -72,6 +73,7 @@ from synapse.util.stringutils import base62_encode
|
|||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.rest.client.v1.login import LoginResponse
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -777,6 +779,108 @@ class AuthHandler(BaseHandler):
|
|||
"params": params,
|
||||
}
|
||||
|
||||
async def refresh_token(
|
||||
self,
|
||||
refresh_token: str,
|
||||
valid_until_ms: Optional[int],
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Consumes a refresh token and generate both a new access token and a new refresh token from it.
|
||||
|
||||
The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
|
||||
|
||||
Args:
|
||||
refresh_token: The token to consume.
|
||||
valid_until_ms: The expiration timestamp of the new access token.
|
||||
|
||||
Returns:
|
||||
A tuple containing the new access token and refresh token
|
||||
"""
|
||||
|
||||
# Verify the token signature first before looking up the token
|
||||
if not self._verify_refresh_token(refresh_token):
|
||||
raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
|
||||
|
||||
existing_token = await self.store.lookup_refresh_token(refresh_token)
|
||||
if existing_token is None:
|
||||
raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
|
||||
|
||||
if (
|
||||
existing_token.has_next_access_token_been_used
|
||||
or existing_token.has_next_refresh_token_been_refreshed
|
||||
):
|
||||
raise SynapseError(
|
||||
403, "refresh token isn't valid anymore", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
(
|
||||
new_refresh_token,
|
||||
new_refresh_token_id,
|
||||
) = await self.get_refresh_token_for_user_id(
|
||||
user_id=existing_token.user_id, device_id=existing_token.device_id
|
||||
)
|
||||
access_token = await self.get_access_token_for_user_id(
|
||||
user_id=existing_token.user_id,
|
||||
device_id=existing_token.device_id,
|
||||
valid_until_ms=valid_until_ms,
|
||||
refresh_token_id=new_refresh_token_id,
|
||||
)
|
||||
await self.store.replace_refresh_token(
|
||||
existing_token.token_id, new_refresh_token_id
|
||||
)
|
||||
return access_token, new_refresh_token
|
||||
|
||||
def _verify_refresh_token(self, token: str) -> bool:
|
||||
"""
|
||||
Verifies the shape of a refresh token.
|
||||
|
||||
Args:
|
||||
token: The refresh token to verify
|
||||
|
||||
Returns:
|
||||
Whether the token has the right shape
|
||||
"""
|
||||
parts = token.split("_", maxsplit=4)
|
||||
if len(parts) != 4:
|
||||
return False
|
||||
|
||||
type, localpart, rand, crc = parts
|
||||
|
||||
# Refresh tokens are prefixed by "syr_", let's check that
|
||||
if type != "syr":
|
||||
return False
|
||||
|
||||
# Check the CRC
|
||||
base = f"{type}_{localpart}_{rand}"
|
||||
expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||
if crc != expected_crc:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def get_refresh_token_for_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
) -> Tuple[str, int]:
|
||||
"""
|
||||
Creates a new refresh token for the user with the given user ID.
|
||||
|
||||
Args:
|
||||
user_id: canonical user ID
|
||||
device_id: the device ID to associate with the token.
|
||||
|
||||
Returns:
|
||||
The newly created refresh token and its ID in the database
|
||||
"""
|
||||
refresh_token = self.generate_refresh_token(UserID.from_string(user_id))
|
||||
refresh_token_id = await self.store.add_refresh_token_to_user(
|
||||
user_id=user_id,
|
||||
token=refresh_token,
|
||||
device_id=device_id,
|
||||
)
|
||||
return refresh_token, refresh_token_id
|
||||
|
||||
async def get_access_token_for_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
|
|
@ -784,6 +888,7 @@ class AuthHandler(BaseHandler):
|
|||
valid_until_ms: Optional[int],
|
||||
puppets_user_id: Optional[str] = None,
|
||||
is_appservice_ghost: bool = False,
|
||||
refresh_token_id: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Creates a new access token for the user with the given user ID.
|
||||
|
|
@ -801,6 +906,8 @@ class AuthHandler(BaseHandler):
|
|||
valid_until_ms: when the token is valid until. None for
|
||||
no expiry.
|
||||
is_appservice_ghost: Whether the user is an application ghost user
|
||||
refresh_token_id: the refresh token ID that will be associated with
|
||||
this access token.
|
||||
Returns:
|
||||
The access token for the user's session.
|
||||
Raises:
|
||||
|
|
@ -836,6 +943,7 @@ class AuthHandler(BaseHandler):
|
|||
device_id=device_id,
|
||||
valid_until_ms=valid_until_ms,
|
||||
puppets_user_id=puppets_user_id,
|
||||
refresh_token_id=refresh_token_id,
|
||||
)
|
||||
|
||||
# the device *should* have been registered before we got here; however,
|
||||
|
|
@ -928,7 +1036,7 @@ class AuthHandler(BaseHandler):
|
|||
self,
|
||||
login_submission: Dict[str, Any],
|
||||
ratelimit: bool = False,
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
|
||||
"""Authenticates the user for the /login API
|
||||
|
||||
Also used by the user-interactive auth flow to validate auth types which don't
|
||||
|
|
@ -1073,7 +1181,7 @@ class AuthHandler(BaseHandler):
|
|||
self,
|
||||
username: str,
|
||||
login_submission: Dict[str, Any],
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
|
||||
"""Helper for validate_login
|
||||
|
||||
Handles login, once we've mapped 3pids onto userids
|
||||
|
|
@ -1151,7 +1259,7 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
async def check_password_provider_3pid(
|
||||
self, medium: str, address: str, password: str
|
||||
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||
) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
|
||||
"""Check if a password provider is able to validate a thirdparty login
|
||||
|
||||
Args:
|
||||
|
|
@ -1215,6 +1323,19 @@ class AuthHandler(BaseHandler):
|
|||
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||
return f"{base}_{crc}"
|
||||
|
||||
def generate_refresh_token(self, for_user: UserID) -> str:
|
||||
"""Generates an opaque string, for use as a refresh token"""
|
||||
|
||||
# we use the following format for refresh tokens:
|
||||
# syr_<base64 local part>_<random string>_<base62 crc check>
|
||||
|
||||
b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8"))
|
||||
random_string = stringutils.random_string(20)
|
||||
base = f"syr_{b64local}_{random_string}"
|
||||
|
||||
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||
return f"{base}_{crc}"
|
||||
|
||||
async def validate_short_term_login_token(
|
||||
self, login_token: str
|
||||
) -> LoginTokenAttributes:
|
||||
|
|
@ -1563,7 +1684,7 @@ class AuthHandler(BaseHandler):
|
|||
)
|
||||
respond_with_html(request, 200, html)
|
||||
|
||||
async def _sso_login_callback(self, login_result: JsonDict) -> None:
|
||||
async def _sso_login_callback(self, login_result: "LoginResponse") -> None:
|
||||
"""
|
||||
A login callback which might add additional attributes to the login response.
|
||||
|
||||
|
|
@ -1577,7 +1698,8 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
extra_attributes = self._extra_attributes.get(login_result["user_id"])
|
||||
if extra_attributes:
|
||||
login_result.update(extra_attributes.extra_attributes)
|
||||
login_result_dict = cast(Dict[str, Any], login_result)
|
||||
login_result_dict.update(extra_attributes.extra_attributes)
|
||||
|
||||
def _expire_sso_extra_attributes(self) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -15,9 +15,10 @@
|
|||
"""Contains functions for registering clients."""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
|
||||
|
||||
from prometheus_client import Counter
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from synapse import types
|
||||
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
|
||||
|
|
@ -54,6 +55,16 @@ login_counter = Counter(
|
|||
["guest", "auth_provider"],
|
||||
)
|
||||
|
||||
LoginDict = TypedDict(
|
||||
"LoginDict",
|
||||
{
|
||||
"device_id": str,
|
||||
"access_token": str,
|
||||
"valid_until_ms": Optional[int],
|
||||
"refresh_token": Optional[str],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class RegistrationHandler(BaseHandler):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
|
@ -85,6 +96,7 @@ class RegistrationHandler(BaseHandler):
|
|||
self.pusher_pool = hs.get_pusherpool()
|
||||
|
||||
self.session_lifetime = hs.config.session_lifetime
|
||||
self.access_token_lifetime = hs.config.access_token_lifetime
|
||||
|
||||
async def check_username(
|
||||
self,
|
||||
|
|
@ -696,7 +708,8 @@ class RegistrationHandler(BaseHandler):
|
|||
is_guest: bool = False,
|
||||
is_appservice_ghost: bool = False,
|
||||
auth_provider_id: Optional[str] = None,
|
||||
) -> Tuple[str, str]:
|
||||
should_issue_refresh_token: bool = False,
|
||||
) -> Tuple[str, str, Optional[int], Optional[str]]:
|
||||
"""Register a device for a user and generate an access token.
|
||||
|
||||
The access token will be limited by the homeserver's session_lifetime config.
|
||||
|
|
@ -708,8 +721,9 @@ class RegistrationHandler(BaseHandler):
|
|||
is_guest: Whether this is a guest account
|
||||
auth_provider_id: The SSO IdP the user used, if any (just used for the
|
||||
prometheus metrics).
|
||||
should_issue_refresh_token: Whether it should also issue a refresh token
|
||||
Returns:
|
||||
Tuple of device ID and access token
|
||||
Tuple of device ID, access token, access token expiration time and refresh token
|
||||
"""
|
||||
res = await self._register_device_client(
|
||||
user_id=user_id,
|
||||
|
|
@ -717,6 +731,7 @@ class RegistrationHandler(BaseHandler):
|
|||
initial_display_name=initial_display_name,
|
||||
is_guest=is_guest,
|
||||
is_appservice_ghost=is_appservice_ghost,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
|
||||
login_counter.labels(
|
||||
|
|
@ -724,7 +739,12 @@ class RegistrationHandler(BaseHandler):
|
|||
auth_provider=(auth_provider_id or ""),
|
||||
).inc()
|
||||
|
||||
return res["device_id"], res["access_token"]
|
||||
return (
|
||||
res["device_id"],
|
||||
res["access_token"],
|
||||
res["valid_until_ms"],
|
||||
res["refresh_token"],
|
||||
)
|
||||
|
||||
async def register_device_inner(
|
||||
self,
|
||||
|
|
@ -733,7 +753,8 @@ class RegistrationHandler(BaseHandler):
|
|||
initial_display_name: Optional[str],
|
||||
is_guest: bool = False,
|
||||
is_appservice_ghost: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
should_issue_refresh_token: bool = False,
|
||||
) -> LoginDict:
|
||||
"""Helper for register_device
|
||||
|
||||
Does the bits that need doing on the main process. Not for use outside this
|
||||
|
|
@ -748,6 +769,9 @@ class RegistrationHandler(BaseHandler):
|
|||
)
|
||||
valid_until_ms = self.clock.time_msec() + self.session_lifetime
|
||||
|
||||
refresh_token = None
|
||||
refresh_token_id = None
|
||||
|
||||
registered_device_id = await self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
|
@ -755,14 +779,30 @@ class RegistrationHandler(BaseHandler):
|
|||
assert valid_until_ms is None
|
||||
access_token = self.macaroon_gen.generate_guest_access_token(user_id)
|
||||
else:
|
||||
if should_issue_refresh_token:
|
||||
(
|
||||
refresh_token,
|
||||
refresh_token_id,
|
||||
) = await self._auth_handler.get_refresh_token_for_user_id(
|
||||
user_id,
|
||||
device_id=registered_device_id,
|
||||
)
|
||||
valid_until_ms = self.clock.time_msec() + self.access_token_lifetime
|
||||
|
||||
access_token = await self._auth_handler.get_access_token_for_user_id(
|
||||
user_id,
|
||||
device_id=registered_device_id,
|
||||
valid_until_ms=valid_until_ms,
|
||||
is_appservice_ghost=is_appservice_ghost,
|
||||
refresh_token_id=refresh_token_id,
|
||||
)
|
||||
|
||||
return {"device_id": registered_device_id, "access_token": access_token}
|
||||
return {
|
||||
"device_id": registered_device_id,
|
||||
"access_token": access_token,
|
||||
"valid_until_ms": valid_until_ms,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
async def post_registration_actions(
|
||||
self, user_id: str, auth_result: dict, access_token: Optional[str]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue