mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-03 18:40:51 -05:00
MSC2918 Refresh tokens implementation (#9450)
This implements refresh tokens, as defined by MSC2918 This MSC has been implemented client side in Hydrogen Web: vector-im/hydrogen-web#235 The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one. Signed-off-by: Quentin Gliech <quentingliech@gmail.com>
This commit is contained in:
parent
763dba77ef
commit
bd4919fb72
1
changelog.d/9450.feature
Normal file
1
changelog.d/9450.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Implement refresh tokens as specified by [MSC2918](https://github.com/matrix-org/matrix-doc/pull/2918).
|
@ -93,6 +93,7 @@ BOOLEAN_COLUMNS = {
|
|||||||
"local_media_repository": ["safe_from_quarantine"],
|
"local_media_repository": ["safe_from_quarantine"],
|
||||||
"users": ["shadow_banned"],
|
"users": ["shadow_banned"],
|
||||||
"e2e_fallback_keys_json": ["used"],
|
"e2e_fallback_keys_json": ["used"],
|
||||||
|
"access_tokens": ["used"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -307,7 +308,8 @@ class Porter(object):
|
|||||||
information_schema.table_constraints AS tc
|
information_schema.table_constraints AS tc
|
||||||
INNER JOIN information_schema.constraint_column_usage AS ccu
|
INNER JOIN information_schema.constraint_column_usage AS ccu
|
||||||
USING (table_schema, constraint_name)
|
USING (table_schema, constraint_name)
|
||||||
WHERE tc.constraint_type = 'FOREIGN KEY';
|
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||||
|
AND tc.table_name != ccu.table_name;
|
||||||
"""
|
"""
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
|
|
||||||
|
@ -245,6 +245,11 @@ class Auth:
|
|||||||
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
|
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mark the token as used. This is used to invalidate old refresh
|
||||||
|
# tokens after some time.
|
||||||
|
if not user_info.token_used and token_id is not None:
|
||||||
|
await self.store.mark_access_token_as_used(token_id)
|
||||||
|
|
||||||
requester = create_requester(
|
requester = create_requester(
|
||||||
user_info.user_id,
|
user_info.user_id,
|
||||||
token_id,
|
token_id,
|
||||||
|
@ -119,6 +119,27 @@ class RegistrationConfig(Config):
|
|||||||
session_lifetime = self.parse_duration(session_lifetime)
|
session_lifetime = self.parse_duration(session_lifetime)
|
||||||
self.session_lifetime = session_lifetime
|
self.session_lifetime = session_lifetime
|
||||||
|
|
||||||
|
# The `access_token_lifetime` applies for tokens that can be renewed
|
||||||
|
# using a refresh token, as per MSC2918. If it is `None`, the refresh
|
||||||
|
# token mechanism is disabled.
|
||||||
|
#
|
||||||
|
# Since it is incompatible with the `session_lifetime` mechanism, it is set to
|
||||||
|
# `None` by default if a `session_lifetime` is set.
|
||||||
|
access_token_lifetime = config.get(
|
||||||
|
"access_token_lifetime", "5m" if session_lifetime is None else None
|
||||||
|
)
|
||||||
|
if access_token_lifetime is not None:
|
||||||
|
access_token_lifetime = self.parse_duration(access_token_lifetime)
|
||||||
|
self.access_token_lifetime = access_token_lifetime
|
||||||
|
|
||||||
|
if session_lifetime is not None and access_token_lifetime is not None:
|
||||||
|
raise ConfigError(
|
||||||
|
"The refresh token mechanism is incompatible with the "
|
||||||
|
"`session_lifetime` option. Consider disabling the "
|
||||||
|
"`session_lifetime` option or disabling the refresh token "
|
||||||
|
"mechanism by removing the `access_token_lifetime` option."
|
||||||
|
)
|
||||||
|
|
||||||
# The success template used during fallback auth.
|
# The success template used during fallback auth.
|
||||||
self.fallback_success_template = self.read_template("auth_success.html")
|
self.fallback_success_template = self.read_template("auth_success.html")
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
@ -72,6 +73,7 @@ from synapse.util.stringutils import base62_encode
|
|||||||
from synapse.util.threepids import canonicalise_email
|
from synapse.util.threepids import canonicalise_email
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from synapse.rest.client.v1.login import LoginResponse
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -777,6 +779,108 @@ class AuthHandler(BaseHandler):
|
|||||||
"params": params,
|
"params": params,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def refresh_token(
|
||||||
|
self,
|
||||||
|
refresh_token: str,
|
||||||
|
valid_until_ms: Optional[int],
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Consumes a refresh token and generate both a new access token and a new refresh token from it.
|
||||||
|
|
||||||
|
The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
refresh_token: The token to consume.
|
||||||
|
valid_until_ms: The expiration timestamp of the new access token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing the new access token and refresh token
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Verify the token signature first before looking up the token
|
||||||
|
if not self._verify_refresh_token(refresh_token):
|
||||||
|
raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
|
||||||
|
|
||||||
|
existing_token = await self.store.lookup_refresh_token(refresh_token)
|
||||||
|
if existing_token is None:
|
||||||
|
raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
|
||||||
|
|
||||||
|
if (
|
||||||
|
existing_token.has_next_access_token_been_used
|
||||||
|
or existing_token.has_next_refresh_token_been_refreshed
|
||||||
|
):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "refresh token isn't valid anymore", Codes.FORBIDDEN
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
new_refresh_token,
|
||||||
|
new_refresh_token_id,
|
||||||
|
) = await self.get_refresh_token_for_user_id(
|
||||||
|
user_id=existing_token.user_id, device_id=existing_token.device_id
|
||||||
|
)
|
||||||
|
access_token = await self.get_access_token_for_user_id(
|
||||||
|
user_id=existing_token.user_id,
|
||||||
|
device_id=existing_token.device_id,
|
||||||
|
valid_until_ms=valid_until_ms,
|
||||||
|
refresh_token_id=new_refresh_token_id,
|
||||||
|
)
|
||||||
|
await self.store.replace_refresh_token(
|
||||||
|
existing_token.token_id, new_refresh_token_id
|
||||||
|
)
|
||||||
|
return access_token, new_refresh_token
|
||||||
|
|
||||||
|
def _verify_refresh_token(self, token: str) -> bool:
|
||||||
|
"""
|
||||||
|
Verifies the shape of a refresh token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The refresh token to verify
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the token has the right shape
|
||||||
|
"""
|
||||||
|
parts = token.split("_", maxsplit=4)
|
||||||
|
if len(parts) != 4:
|
||||||
|
return False
|
||||||
|
|
||||||
|
type, localpart, rand, crc = parts
|
||||||
|
|
||||||
|
# Refresh tokens are prefixed by "syr_", let's check that
|
||||||
|
if type != "syr":
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check the CRC
|
||||||
|
base = f"{type}_{localpart}_{rand}"
|
||||||
|
expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||||
|
if crc != expected_crc:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_refresh_token_for_user_id(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
) -> Tuple[str, int]:
|
||||||
|
"""
|
||||||
|
Creates a new refresh token for the user with the given user ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: canonical user ID
|
||||||
|
device_id: the device ID to associate with the token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The newly created refresh token and its ID in the database
|
||||||
|
"""
|
||||||
|
refresh_token = self.generate_refresh_token(UserID.from_string(user_id))
|
||||||
|
refresh_token_id = await self.store.add_refresh_token_to_user(
|
||||||
|
user_id=user_id,
|
||||||
|
token=refresh_token,
|
||||||
|
device_id=device_id,
|
||||||
|
)
|
||||||
|
return refresh_token, refresh_token_id
|
||||||
|
|
||||||
async def get_access_token_for_user_id(
|
async def get_access_token_for_user_id(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
@ -784,6 +888,7 @@ class AuthHandler(BaseHandler):
|
|||||||
valid_until_ms: Optional[int],
|
valid_until_ms: Optional[int],
|
||||||
puppets_user_id: Optional[str] = None,
|
puppets_user_id: Optional[str] = None,
|
||||||
is_appservice_ghost: bool = False,
|
is_appservice_ghost: bool = False,
|
||||||
|
refresh_token_id: Optional[int] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Creates a new access token for the user with the given user ID.
|
Creates a new access token for the user with the given user ID.
|
||||||
@ -801,6 +906,8 @@ class AuthHandler(BaseHandler):
|
|||||||
valid_until_ms: when the token is valid until. None for
|
valid_until_ms: when the token is valid until. None for
|
||||||
no expiry.
|
no expiry.
|
||||||
is_appservice_ghost: Whether the user is an application ghost user
|
is_appservice_ghost: Whether the user is an application ghost user
|
||||||
|
refresh_token_id: the refresh token ID that will be associated with
|
||||||
|
this access token.
|
||||||
Returns:
|
Returns:
|
||||||
The access token for the user's session.
|
The access token for the user's session.
|
||||||
Raises:
|
Raises:
|
||||||
@ -836,6 +943,7 @@ class AuthHandler(BaseHandler):
|
|||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
valid_until_ms=valid_until_ms,
|
valid_until_ms=valid_until_ms,
|
||||||
puppets_user_id=puppets_user_id,
|
puppets_user_id=puppets_user_id,
|
||||||
|
refresh_token_id=refresh_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# the device *should* have been registered before we got here; however,
|
# the device *should* have been registered before we got here; however,
|
||||||
@ -928,7 +1036,7 @@ class AuthHandler(BaseHandler):
|
|||||||
self,
|
self,
|
||||||
login_submission: Dict[str, Any],
|
login_submission: Dict[str, Any],
|
||||||
ratelimit: bool = False,
|
ratelimit: bool = False,
|
||||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
|
||||||
"""Authenticates the user for the /login API
|
"""Authenticates the user for the /login API
|
||||||
|
|
||||||
Also used by the user-interactive auth flow to validate auth types which don't
|
Also used by the user-interactive auth flow to validate auth types which don't
|
||||||
@ -1073,7 +1181,7 @@ class AuthHandler(BaseHandler):
|
|||||||
self,
|
self,
|
||||||
username: str,
|
username: str,
|
||||||
login_submission: Dict[str, Any],
|
login_submission: Dict[str, Any],
|
||||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
|
||||||
"""Helper for validate_login
|
"""Helper for validate_login
|
||||||
|
|
||||||
Handles login, once we've mapped 3pids onto userids
|
Handles login, once we've mapped 3pids onto userids
|
||||||
@ -1151,7 +1259,7 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
async def check_password_provider_3pid(
|
async def check_password_provider_3pid(
|
||||||
self, medium: str, address: str, password: str
|
self, medium: str, address: str, password: str
|
||||||
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
|
||||||
"""Check if a password provider is able to validate a thirdparty login
|
"""Check if a password provider is able to validate a thirdparty login
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1215,6 +1323,19 @@ class AuthHandler(BaseHandler):
|
|||||||
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||||
return f"{base}_{crc}"
|
return f"{base}_{crc}"
|
||||||
|
|
||||||
|
def generate_refresh_token(self, for_user: UserID) -> str:
|
||||||
|
"""Generates an opaque string, for use as a refresh token"""
|
||||||
|
|
||||||
|
# we use the following format for refresh tokens:
|
||||||
|
# syr_<base64 local part>_<random string>_<base62 crc check>
|
||||||
|
|
||||||
|
b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8"))
|
||||||
|
random_string = stringutils.random_string(20)
|
||||||
|
base = f"syr_{b64local}_{random_string}"
|
||||||
|
|
||||||
|
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||||
|
return f"{base}_{crc}"
|
||||||
|
|
||||||
async def validate_short_term_login_token(
|
async def validate_short_term_login_token(
|
||||||
self, login_token: str
|
self, login_token: str
|
||||||
) -> LoginTokenAttributes:
|
) -> LoginTokenAttributes:
|
||||||
@ -1563,7 +1684,7 @@ class AuthHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
respond_with_html(request, 200, html)
|
respond_with_html(request, 200, html)
|
||||||
|
|
||||||
async def _sso_login_callback(self, login_result: JsonDict) -> None:
|
async def _sso_login_callback(self, login_result: "LoginResponse") -> None:
|
||||||
"""
|
"""
|
||||||
A login callback which might add additional attributes to the login response.
|
A login callback which might add additional attributes to the login response.
|
||||||
|
|
||||||
@ -1577,7 +1698,8 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
extra_attributes = self._extra_attributes.get(login_result["user_id"])
|
extra_attributes = self._extra_attributes.get(login_result["user_id"])
|
||||||
if extra_attributes:
|
if extra_attributes:
|
||||||
login_result.update(extra_attributes.extra_attributes)
|
login_result_dict = cast(Dict[str, Any], login_result)
|
||||||
|
login_result_dict.update(extra_attributes.extra_attributes)
|
||||||
|
|
||||||
def _expire_sso_extra_attributes(self) -> None:
|
def _expire_sso_extra_attributes(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -15,9 +15,10 @@
|
|||||||
"""Contains functions for registering clients."""
|
"""Contains functions for registering clients."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from synapse import types
|
from synapse import types
|
||||||
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
|
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
|
||||||
@ -54,6 +55,16 @@ login_counter = Counter(
|
|||||||
["guest", "auth_provider"],
|
["guest", "auth_provider"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LoginDict = TypedDict(
|
||||||
|
"LoginDict",
|
||||||
|
{
|
||||||
|
"device_id": str,
|
||||||
|
"access_token": str,
|
||||||
|
"valid_until_ms": Optional[int],
|
||||||
|
"refresh_token": Optional[str],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RegistrationHandler(BaseHandler):
|
class RegistrationHandler(BaseHandler):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
@ -85,6 +96,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
self.pusher_pool = hs.get_pusherpool()
|
self.pusher_pool = hs.get_pusherpool()
|
||||||
|
|
||||||
self.session_lifetime = hs.config.session_lifetime
|
self.session_lifetime = hs.config.session_lifetime
|
||||||
|
self.access_token_lifetime = hs.config.access_token_lifetime
|
||||||
|
|
||||||
async def check_username(
|
async def check_username(
|
||||||
self,
|
self,
|
||||||
@ -696,7 +708,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
is_guest: bool = False,
|
is_guest: bool = False,
|
||||||
is_appservice_ghost: bool = False,
|
is_appservice_ghost: bool = False,
|
||||||
auth_provider_id: Optional[str] = None,
|
auth_provider_id: Optional[str] = None,
|
||||||
) -> Tuple[str, str]:
|
should_issue_refresh_token: bool = False,
|
||||||
|
) -> Tuple[str, str, Optional[int], Optional[str]]:
|
||||||
"""Register a device for a user and generate an access token.
|
"""Register a device for a user and generate an access token.
|
||||||
|
|
||||||
The access token will be limited by the homeserver's session_lifetime config.
|
The access token will be limited by the homeserver's session_lifetime config.
|
||||||
@ -708,8 +721,9 @@ class RegistrationHandler(BaseHandler):
|
|||||||
is_guest: Whether this is a guest account
|
is_guest: Whether this is a guest account
|
||||||
auth_provider_id: The SSO IdP the user used, if any (just used for the
|
auth_provider_id: The SSO IdP the user used, if any (just used for the
|
||||||
prometheus metrics).
|
prometheus metrics).
|
||||||
|
should_issue_refresh_token: Whether it should also issue a refresh token
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of device ID and access token
|
Tuple of device ID, access token, access token expiration time and refresh token
|
||||||
"""
|
"""
|
||||||
res = await self._register_device_client(
|
res = await self._register_device_client(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@ -717,6 +731,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
initial_display_name=initial_display_name,
|
initial_display_name=initial_display_name,
|
||||||
is_guest=is_guest,
|
is_guest=is_guest,
|
||||||
is_appservice_ghost=is_appservice_ghost,
|
is_appservice_ghost=is_appservice_ghost,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
login_counter.labels(
|
login_counter.labels(
|
||||||
@ -724,7 +739,12 @@ class RegistrationHandler(BaseHandler):
|
|||||||
auth_provider=(auth_provider_id or ""),
|
auth_provider=(auth_provider_id or ""),
|
||||||
).inc()
|
).inc()
|
||||||
|
|
||||||
return res["device_id"], res["access_token"]
|
return (
|
||||||
|
res["device_id"],
|
||||||
|
res["access_token"],
|
||||||
|
res["valid_until_ms"],
|
||||||
|
res["refresh_token"],
|
||||||
|
)
|
||||||
|
|
||||||
async def register_device_inner(
|
async def register_device_inner(
|
||||||
self,
|
self,
|
||||||
@ -733,7 +753,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
initial_display_name: Optional[str],
|
initial_display_name: Optional[str],
|
||||||
is_guest: bool = False,
|
is_guest: bool = False,
|
||||||
is_appservice_ghost: bool = False,
|
is_appservice_ghost: bool = False,
|
||||||
) -> Dict[str, str]:
|
should_issue_refresh_token: bool = False,
|
||||||
|
) -> LoginDict:
|
||||||
"""Helper for register_device
|
"""Helper for register_device
|
||||||
|
|
||||||
Does the bits that need doing on the main process. Not for use outside this
|
Does the bits that need doing on the main process. Not for use outside this
|
||||||
@ -748,6 +769,9 @@ class RegistrationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
valid_until_ms = self.clock.time_msec() + self.session_lifetime
|
valid_until_ms = self.clock.time_msec() + self.session_lifetime
|
||||||
|
|
||||||
|
refresh_token = None
|
||||||
|
refresh_token_id = None
|
||||||
|
|
||||||
registered_device_id = await self.device_handler.check_device_registered(
|
registered_device_id = await self.device_handler.check_device_registered(
|
||||||
user_id, device_id, initial_display_name
|
user_id, device_id, initial_display_name
|
||||||
)
|
)
|
||||||
@ -755,14 +779,30 @@ class RegistrationHandler(BaseHandler):
|
|||||||
assert valid_until_ms is None
|
assert valid_until_ms is None
|
||||||
access_token = self.macaroon_gen.generate_guest_access_token(user_id)
|
access_token = self.macaroon_gen.generate_guest_access_token(user_id)
|
||||||
else:
|
else:
|
||||||
|
if should_issue_refresh_token:
|
||||||
|
(
|
||||||
|
refresh_token,
|
||||||
|
refresh_token_id,
|
||||||
|
) = await self._auth_handler.get_refresh_token_for_user_id(
|
||||||
|
user_id,
|
||||||
|
device_id=registered_device_id,
|
||||||
|
)
|
||||||
|
valid_until_ms = self.clock.time_msec() + self.access_token_lifetime
|
||||||
|
|
||||||
access_token = await self._auth_handler.get_access_token_for_user_id(
|
access_token = await self._auth_handler.get_access_token_for_user_id(
|
||||||
user_id,
|
user_id,
|
||||||
device_id=registered_device_id,
|
device_id=registered_device_id,
|
||||||
valid_until_ms=valid_until_ms,
|
valid_until_ms=valid_until_ms,
|
||||||
is_appservice_ghost=is_appservice_ghost,
|
is_appservice_ghost=is_appservice_ghost,
|
||||||
|
refresh_token_id=refresh_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"device_id": registered_device_id, "access_token": access_token}
|
return {
|
||||||
|
"device_id": registered_device_id,
|
||||||
|
"access_token": access_token,
|
||||||
|
"valid_until_ms": valid_until_ms,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
}
|
||||||
|
|
||||||
async def post_registration_actions(
|
async def post_registration_actions(
|
||||||
self, user_id: str, auth_result: dict, access_token: Optional[str]
|
self, user_id: str, auth_result: dict, access_token: Optional[str]
|
||||||
|
@ -168,7 +168,7 @@ class ModuleApi:
|
|||||||
"Using deprecated ModuleApi.register which creates a dummy user device."
|
"Using deprecated ModuleApi.register which creates a dummy user device."
|
||||||
)
|
)
|
||||||
user_id = yield self.register_user(localpart, displayname, emails or [])
|
user_id = yield self.register_user(localpart, displayname, emails or [])
|
||||||
_, access_token = yield self.register_device(user_id)
|
_, access_token, _, _ = yield self.register_device(user_id)
|
||||||
return user_id, access_token
|
return user_id, access_token
|
||||||
|
|
||||||
def register_user(
|
def register_user(
|
||||||
|
@ -36,20 +36,29 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(
|
async def _serialize_payload(
|
||||||
user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
|
user_id,
|
||||||
|
device_id,
|
||||||
|
initial_display_name,
|
||||||
|
is_guest,
|
||||||
|
is_appservice_ghost,
|
||||||
|
should_issue_refresh_token,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
user_id (int)
|
||||||
device_id (str|None): Device ID to use, if None a new one is
|
device_id (str|None): Device ID to use, if None a new one is
|
||||||
generated.
|
generated.
|
||||||
initial_display_name (str|None)
|
initial_display_name (str|None)
|
||||||
is_guest (bool)
|
is_guest (bool)
|
||||||
|
is_appservice_ghost (bool)
|
||||||
|
should_issue_refresh_token (bool)
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
"initial_display_name": initial_display_name,
|
"initial_display_name": initial_display_name,
|
||||||
"is_guest": is_guest,
|
"is_guest": is_guest,
|
||||||
"is_appservice_ghost": is_appservice_ghost,
|
"is_appservice_ghost": is_appservice_ghost,
|
||||||
|
"should_issue_refresh_token": should_issue_refresh_token,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request(self, request, user_id):
|
||||||
@ -59,6 +68,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||||||
initial_display_name = content["initial_display_name"]
|
initial_display_name = content["initial_display_name"]
|
||||||
is_guest = content["is_guest"]
|
is_guest = content["is_guest"]
|
||||||
is_appservice_ghost = content["is_appservice_ghost"]
|
is_appservice_ghost = content["is_appservice_ghost"]
|
||||||
|
should_issue_refresh_token = content["should_issue_refresh_token"]
|
||||||
|
|
||||||
res = await self.registration_handler.register_device_inner(
|
res = await self.registration_handler.register_device_inner(
|
||||||
user_id,
|
user_id,
|
||||||
@ -66,6 +76,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||||||
initial_display_name,
|
initial_display_name,
|
||||||
is_guest,
|
is_guest,
|
||||||
is_appservice_ghost=is_appservice_ghost,
|
is_appservice_ghost=is_appservice_ghost,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, res
|
return 200, res
|
||||||
|
@ -14,7 +14,9 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
@ -25,6 +27,8 @@ from synapse.http import get_request_uri
|
|||||||
from synapse.http.server import HttpServer, finish_request
|
from synapse.http.server import HttpServer, finish_request
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet,
|
RestServlet,
|
||||||
|
assert_params_in_dict,
|
||||||
|
parse_boolean,
|
||||||
parse_bytes_from_args,
|
parse_bytes_from_args,
|
||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
parse_string,
|
parse_string,
|
||||||
@ -40,6 +44,21 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
LoginResponse = TypedDict(
|
||||||
|
"LoginResponse",
|
||||||
|
{
|
||||||
|
"user_id": str,
|
||||||
|
"access_token": str,
|
||||||
|
"home_server": str,
|
||||||
|
"expires_in_ms": Optional[int],
|
||||||
|
"refresh_token": Optional[str],
|
||||||
|
"device_id": str,
|
||||||
|
"well_known": Optional[Dict[str, Any]],
|
||||||
|
},
|
||||||
|
total=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LoginRestServlet(RestServlet):
|
class LoginRestServlet(RestServlet):
|
||||||
PATTERNS = client_patterns("/login$", v1=True)
|
PATTERNS = client_patterns("/login$", v1=True)
|
||||||
CAS_TYPE = "m.login.cas"
|
CAS_TYPE = "m.login.cas"
|
||||||
@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
JWT_TYPE = "org.matrix.login.jwt"
|
JWT_TYPE = "org.matrix.login.jwt"
|
||||||
JWT_TYPE_DEPRECATED = "m.login.jwt"
|
JWT_TYPE_DEPRECATED = "m.login.jwt"
|
||||||
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
|
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
|
||||||
|
REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet):
|
|||||||
self.cas_enabled = hs.config.cas_enabled
|
self.cas_enabled = hs.config.cas_enabled
|
||||||
self.oidc_enabled = hs.config.oidc_enabled
|
self.oidc_enabled = hs.config.oidc_enabled
|
||||||
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
|
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
|
||||||
|
self._msc2918_enabled = hs.config.access_token_lifetime is not None
|
||||||
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
self.auth_handler = self.hs.get_auth_handler()
|
self.auth_handler = self.hs.get_auth_handler()
|
||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet):
|
|||||||
async def on_POST(self, request: SynapseRequest):
|
async def on_POST(self, request: SynapseRequest):
|
||||||
login_submission = parse_json_object_from_request(request)
|
login_submission = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
if self._msc2918_enabled:
|
||||||
|
# Check if this login should also issue a refresh token, as per
|
||||||
|
# MSC2918
|
||||||
|
should_issue_refresh_token = parse_boolean(
|
||||||
|
request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
should_issue_refresh_token = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
||||||
appservice = self.auth.get_appservice_by_req(request)
|
appservice = self.auth.get_appservice_by_req(request)
|
||||||
@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet):
|
|||||||
None, request.getClientIP()
|
None, request.getClientIP()
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await self._do_appservice_login(login_submission, appservice)
|
result = await self._do_appservice_login(
|
||||||
|
login_submission,
|
||||||
|
appservice,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
)
|
||||||
elif self.jwt_enabled and (
|
elif self.jwt_enabled and (
|
||||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||||
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
|
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
|
||||||
):
|
):
|
||||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||||
result = await self._do_jwt_login(login_submission)
|
result = await self._do_jwt_login(
|
||||||
|
login_submission,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
)
|
||||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||||
result = await self._do_token_login(login_submission)
|
result = await self._do_token_login(
|
||||||
|
login_submission,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||||
result = await self._do_other_login(login_submission)
|
result = await self._do_other_login(
|
||||||
|
login_submission,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise SynapseError(400, "Missing JSON keys.")
|
raise SynapseError(400, "Missing JSON keys.")
|
||||||
|
|
||||||
@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet):
|
|||||||
return 200, result
|
return 200, result
|
||||||
|
|
||||||
async def _do_appservice_login(
|
async def _do_appservice_login(
|
||||||
self, login_submission: JsonDict, appservice: ApplicationService
|
self,
|
||||||
|
login_submission: JsonDict,
|
||||||
|
appservice: ApplicationService,
|
||||||
|
should_issue_refresh_token: bool = False,
|
||||||
):
|
):
|
||||||
identifier = login_submission.get("identifier")
|
identifier = login_submission.get("identifier")
|
||||||
logger.info("Got appservice login request with identifier: %r", identifier)
|
logger.info("Got appservice login request with identifier: %r", identifier)
|
||||||
@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet):
|
|||||||
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
return await self._complete_login(
|
return await self._complete_login(
|
||||||
qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
|
qualified_user_id,
|
||||||
|
login_submission,
|
||||||
|
ratelimit=appservice.is_rate_limited(),
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
async def _do_other_login(
|
||||||
|
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
||||||
|
) -> LoginResponse:
|
||||||
"""Handle non-token/saml/jwt logins
|
"""Handle non-token/saml/jwt logins
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
login_submission:
|
login_submission:
|
||||||
|
should_issue_refresh_token: True if this login should issue
|
||||||
|
a refresh token alongside the access token.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
HTTP response
|
HTTP response
|
||||||
@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet):
|
|||||||
login_submission, ratelimit=True
|
login_submission, ratelimit=True
|
||||||
)
|
)
|
||||||
result = await self._complete_login(
|
result = await self._complete_login(
|
||||||
canonical_user_id, login_submission, callback
|
canonical_user_id,
|
||||||
|
login_submission,
|
||||||
|
callback,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet):
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
login_submission: JsonDict,
|
login_submission: JsonDict,
|
||||||
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
|
callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
|
||||||
create_non_existent_users: bool = False,
|
create_non_existent_users: bool = False,
|
||||||
ratelimit: bool = True,
|
ratelimit: bool = True,
|
||||||
auth_provider_id: Optional[str] = None,
|
auth_provider_id: Optional[str] = None,
|
||||||
) -> Dict[str, str]:
|
should_issue_refresh_token: bool = False,
|
||||||
|
) -> LoginResponse:
|
||||||
"""Called when we've successfully authed the user and now need to
|
"""Called when we've successfully authed the user and now need to
|
||||||
actually login them in (e.g. create devices). This gets called on
|
actually login them in (e.g. create devices). This gets called on
|
||||||
all successful logins.
|
all successful logins.
|
||||||
@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet):
|
|||||||
ratelimit: Whether to ratelimit the login request.
|
ratelimit: Whether to ratelimit the login request.
|
||||||
auth_provider_id: The SSO IdP the user used, if any (just used for the
|
auth_provider_id: The SSO IdP the user used, if any (just used for the
|
||||||
prometheus metrics).
|
prometheus metrics).
|
||||||
|
should_issue_refresh_token: True if this login should issue
|
||||||
|
a refresh token alongside the access token.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
result: Dictionary of account information after successful login.
|
result: Dictionary of account information after successful login.
|
||||||
@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet):
|
|||||||
|
|
||||||
device_id = login_submission.get("device_id")
|
device_id = login_submission.get("device_id")
|
||||||
initial_display_name = login_submission.get("initial_device_display_name")
|
initial_display_name = login_submission.get("initial_device_display_name")
|
||||||
device_id, access_token = await self.registration_handler.register_device(
|
(
|
||||||
user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
|
device_id,
|
||||||
|
access_token,
|
||||||
|
valid_until_ms,
|
||||||
|
refresh_token,
|
||||||
|
) = await self.registration_handler.register_device(
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
initial_display_name,
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = LoginResponse(
|
||||||
"user_id": user_id,
|
user_id=user_id,
|
||||||
"access_token": access_token,
|
access_token=access_token,
|
||||||
"home_server": self.hs.hostname,
|
home_server=self.hs.hostname,
|
||||||
"device_id": device_id,
|
device_id=device_id,
|
||||||
}
|
)
|
||||||
|
|
||||||
|
if valid_until_ms is not None:
|
||||||
|
expires_in_ms = valid_until_ms - self.clock.time_msec()
|
||||||
|
result["expires_in_ms"] = expires_in_ms
|
||||||
|
|
||||||
|
if refresh_token is not None:
|
||||||
|
result["refresh_token"] = refresh_token
|
||||||
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
await callback(result)
|
await callback(result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
async def _do_token_login(
|
||||||
|
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
||||||
|
) -> LoginResponse:
|
||||||
"""
|
"""
|
||||||
Handle the final stage of SSO login.
|
Handle the final stage of SSO login.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
login_submission: The JSON request body.
|
login_submission: The JSON request body.
|
||||||
|
should_issue_refresh_token: True if this login should issue
|
||||||
|
a refresh token alongside the access token.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The body of the JSON response.
|
The body of the JSON response.
|
||||||
@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet):
|
|||||||
login_submission,
|
login_submission,
|
||||||
self.auth_handler._sso_login_callback,
|
self.auth_handler._sso_login_callback,
|
||||||
auth_provider_id=res.auth_provider_id,
|
auth_provider_id=res.auth_provider_id,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
async def _do_jwt_login(
|
||||||
|
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
||||||
|
) -> LoginResponse:
|
||||||
token = login_submission.get("token", None)
|
token = login_submission.get("token", None)
|
||||||
if token is None:
|
if token is None:
|
||||||
raise LoginError(
|
raise LoginError(
|
||||||
@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet):
|
|||||||
|
|
||||||
user_id = UserID(user, self.hs.hostname).to_string()
|
user_id = UserID(user, self.hs.hostname).to_string()
|
||||||
result = await self._complete_login(
|
result = await self._complete_login(
|
||||||
user_id, login_submission, create_non_existent_users=True
|
user_id,
|
||||||
|
login_submission,
|
||||||
|
create_non_existent_users=True,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp(
|
|||||||
return e
|
return e
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenServlet(RestServlet):
|
||||||
|
PATTERNS = client_patterns(
|
||||||
|
"/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
self._clock = hs.get_clock()
|
||||||
|
self.access_token_lifetime = hs.config.access_token_lifetime
|
||||||
|
|
||||||
|
async def on_POST(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
):
|
||||||
|
refresh_submission = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
assert_params_in_dict(refresh_submission, ["refresh_token"])
|
||||||
|
token = refresh_submission["refresh_token"]
|
||||||
|
if not isinstance(token, str):
|
||||||
|
raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
|
||||||
|
access_token, refresh_token = await self._auth_handler.refresh_token(
|
||||||
|
token, valid_until_ms
|
||||||
|
)
|
||||||
|
expires_in_ms = valid_until_ms - self._clock.time_msec()
|
||||||
|
return (
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"access_token": access_token,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"expires_in_ms": expires_in_ms,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SsoRedirectServlet(RestServlet):
|
class SsoRedirectServlet(RestServlet):
|
||||||
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
|
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
|
||||||
re.compile(
|
re.compile(
|
||||||
@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet):
|
|||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
LoginRestServlet(hs).register(http_server)
|
LoginRestServlet(hs).register(http_server)
|
||||||
|
if hs.config.access_token_lifetime is not None:
|
||||||
|
RefreshTokenServlet(hs).register(http_server)
|
||||||
SsoRedirectServlet(hs).register(http_server)
|
SsoRedirectServlet(hs).register(http_server)
|
||||||
if hs.config.cas_enabled:
|
if hs.config.cas_enabled:
|
||||||
CasTicketServlet(hs).register(http_server)
|
CasTicketServlet(hs).register(http_server)
|
||||||
|
@ -41,11 +41,13 @@ from synapse.http.server import finish_request, respond_with_html
|
|||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet,
|
RestServlet,
|
||||||
assert_params_in_dict,
|
assert_params_in_dict,
|
||||||
|
parse_boolean,
|
||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
parse_string,
|
parse_string,
|
||||||
)
|
)
|
||||||
from synapse.metrics import threepid_send_requests
|
from synapse.metrics import threepid_send_requests
|
||||||
from synapse.push.mailer import Mailer
|
from synapse.push.mailer import Mailer
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
from synapse.util.stringutils import assert_valid_client_secret, random_string
|
from synapse.util.stringutils import assert_valid_client_secret, random_string
|
||||||
@ -399,6 +401,7 @@ class RegisterRestServlet(RestServlet):
|
|||||||
self.password_policy_handler = hs.get_password_policy_handler()
|
self.password_policy_handler = hs.get_password_policy_handler()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self._registration_enabled = self.hs.config.enable_registration
|
self._registration_enabled = self.hs.config.enable_registration
|
||||||
|
self._msc2918_enabled = hs.config.access_token_lifetime is not None
|
||||||
|
|
||||||
self._registration_flows = _calculate_registration_flows(
|
self._registration_flows = _calculate_registration_flows(
|
||||||
hs.config, self.auth_handler
|
hs.config, self.auth_handler
|
||||||
@ -424,6 +427,15 @@ class RegisterRestServlet(RestServlet):
|
|||||||
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
|
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._msc2918_enabled:
|
||||||
|
# Check if this registration should also issue a refresh token, as
|
||||||
|
# per MSC2918
|
||||||
|
should_issue_refresh_token = parse_boolean(
|
||||||
|
request, name="org.matrix.msc2918.refresh_token", default=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
should_issue_refresh_token = False
|
||||||
|
|
||||||
# Pull out the provided username and do basic sanity checks early since
|
# Pull out the provided username and do basic sanity checks early since
|
||||||
# the auth layer will store these in sessions.
|
# the auth layer will store these in sessions.
|
||||||
desired_username = None
|
desired_username = None
|
||||||
@ -462,7 +474,10 @@ class RegisterRestServlet(RestServlet):
|
|||||||
raise SynapseError(400, "Desired Username is missing or not a string")
|
raise SynapseError(400, "Desired Username is missing or not a string")
|
||||||
|
|
||||||
result = await self._do_appservice_registration(
|
result = await self._do_appservice_registration(
|
||||||
desired_username, access_token, body
|
desired_username,
|
||||||
|
access_token,
|
||||||
|
body,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, result
|
return 200, result
|
||||||
@ -665,7 +680,9 @@ class RegisterRestServlet(RestServlet):
|
|||||||
registered = True
|
registered = True
|
||||||
|
|
||||||
return_dict = await self._create_registration_details(
|
return_dict = await self._create_registration_details(
|
||||||
registered_user_id, params
|
registered_user_id,
|
||||||
|
params,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
if registered:
|
if registered:
|
||||||
@ -677,7 +694,9 @@ class RegisterRestServlet(RestServlet):
|
|||||||
|
|
||||||
return 200, return_dict
|
return 200, return_dict
|
||||||
|
|
||||||
async def _do_appservice_registration(self, username, as_token, body):
|
async def _do_appservice_registration(
|
||||||
|
self, username, as_token, body, should_issue_refresh_token: bool = False
|
||||||
|
):
|
||||||
user_id = await self.registration_handler.appservice_register(
|
user_id = await self.registration_handler.appservice_register(
|
||||||
username, as_token
|
username, as_token
|
||||||
)
|
)
|
||||||
@ -685,19 +704,27 @@ class RegisterRestServlet(RestServlet):
|
|||||||
user_id,
|
user_id,
|
||||||
body,
|
body,
|
||||||
is_appservice_ghost=True,
|
is_appservice_ghost=True,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _create_registration_details(
|
async def _create_registration_details(
|
||||||
self, user_id, params, is_appservice_ghost=False
|
self,
|
||||||
|
user_id: str,
|
||||||
|
params: JsonDict,
|
||||||
|
is_appservice_ghost: bool = False,
|
||||||
|
should_issue_refresh_token: bool = False,
|
||||||
):
|
):
|
||||||
"""Complete registration of newly-registered user
|
"""Complete registration of newly-registered user
|
||||||
|
|
||||||
Allocates device_id if one was not given; also creates access_token.
|
Allocates device_id if one was not given; also creates access_token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
(str) user_id: full canonical @user:id
|
user_id: full canonical @user:id
|
||||||
(object) params: registration parameters, from which we pull
|
params: registration parameters, from which we pull device_id,
|
||||||
device_id, initial_device_name and inhibit_login
|
initial_device_name and inhibit_login
|
||||||
|
is_appservice_ghost
|
||||||
|
should_issue_refresh_token: True if this registration should issue
|
||||||
|
a refresh token alongside the access token.
|
||||||
Returns:
|
Returns:
|
||||||
dictionary for response from /register
|
dictionary for response from /register
|
||||||
"""
|
"""
|
||||||
@ -705,15 +732,29 @@ class RegisterRestServlet(RestServlet):
|
|||||||
if not params.get("inhibit_login", False):
|
if not params.get("inhibit_login", False):
|
||||||
device_id = params.get("device_id")
|
device_id = params.get("device_id")
|
||||||
initial_display_name = params.get("initial_device_display_name")
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
device_id, access_token = await self.registration_handler.register_device(
|
(
|
||||||
|
device_id,
|
||||||
|
access_token,
|
||||||
|
valid_until_ms,
|
||||||
|
refresh_token,
|
||||||
|
) = await self.registration_handler.register_device(
|
||||||
user_id,
|
user_id,
|
||||||
device_id,
|
device_id,
|
||||||
initial_display_name,
|
initial_display_name,
|
||||||
is_guest=False,
|
is_guest=False,
|
||||||
is_appservice_ghost=is_appservice_ghost,
|
is_appservice_ghost=is_appservice_ghost,
|
||||||
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
result.update({"access_token": access_token, "device_id": device_id})
|
result.update({"access_token": access_token, "device_id": device_id})
|
||||||
|
|
||||||
|
if valid_until_ms is not None:
|
||||||
|
expires_in_ms = valid_until_ms - self.clock.time_msec()
|
||||||
|
result["expires_in_ms"] = expires_in_ms
|
||||||
|
|
||||||
|
if refresh_token is not None:
|
||||||
|
result["refresh_token"] = refresh_token
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _do_guest_registration(self, params, address=None):
|
async def _do_guest_registration(self, params, address=None):
|
||||||
@ -727,19 +768,30 @@ class RegisterRestServlet(RestServlet):
|
|||||||
# we have nowhere to store it.
|
# we have nowhere to store it.
|
||||||
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
||||||
initial_display_name = params.get("initial_device_display_name")
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
device_id, access_token = await self.registration_handler.register_device(
|
(
|
||||||
|
device_id,
|
||||||
|
access_token,
|
||||||
|
valid_until_ms,
|
||||||
|
refresh_token,
|
||||||
|
) = await self.registration_handler.register_device(
|
||||||
user_id, device_id, initial_display_name, is_guest=True
|
user_id, device_id, initial_display_name, is_guest=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
result = {
|
||||||
200,
|
|
||||||
{
|
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
},
|
}
|
||||||
)
|
|
||||||
|
if valid_until_ms is not None:
|
||||||
|
expires_in_ms = valid_until_ms - self.clock.time_msec()
|
||||||
|
result["expires_in_ms"] = expires_in_ms
|
||||||
|
|
||||||
|
if refresh_token is not None:
|
||||||
|
result["refresh_token"] = refresh_token
|
||||||
|
|
||||||
|
return 200, result
|
||||||
|
|
||||||
|
|
||||||
def _calculate_registration_flows(
|
def _calculate_registration_flows(
|
||||||
|
@ -53,6 +53,9 @@ class TokenLookupResult:
|
|||||||
valid_until_ms: The timestamp the token expires, if any.
|
valid_until_ms: The timestamp the token expires, if any.
|
||||||
token_owner: The "owner" of the token. This is either the same as the
|
token_owner: The "owner" of the token. This is either the same as the
|
||||||
user, or a server admin who is logged in as the user.
|
user, or a server admin who is logged in as the user.
|
||||||
|
token_used: True if this token was used at least once in a request.
|
||||||
|
This field can be out of date since `get_user_by_access_token` is
|
||||||
|
cached.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
user_id = attr.ib(type=str)
|
user_id = attr.ib(type=str)
|
||||||
@ -62,6 +65,7 @@ class TokenLookupResult:
|
|||||||
device_id = attr.ib(type=Optional[str], default=None)
|
device_id = attr.ib(type=Optional[str], default=None)
|
||||||
valid_until_ms = attr.ib(type=Optional[int], default=None)
|
valid_until_ms = attr.ib(type=Optional[int], default=None)
|
||||||
token_owner = attr.ib(type=str)
|
token_owner = attr.ib(type=str)
|
||||||
|
token_used = attr.ib(type=bool, default=False)
|
||||||
|
|
||||||
# Make the token owner default to the user ID, which is the common case.
|
# Make the token owner default to the user ID, which is the common case.
|
||||||
@token_owner.default
|
@token_owner.default
|
||||||
@ -69,6 +73,29 @@ class TokenLookupResult:
|
|||||||
return self.user_id
|
return self.user_id
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(frozen=True, slots=True)
|
||||||
|
class RefreshTokenLookupResult:
|
||||||
|
"""Result of looking up a refresh token."""
|
||||||
|
|
||||||
|
user_id = attr.ib(type=str)
|
||||||
|
"""The user this token belongs to."""
|
||||||
|
|
||||||
|
device_id = attr.ib(type=str)
|
||||||
|
"""The device associated with this refresh token."""
|
||||||
|
|
||||||
|
token_id = attr.ib(type=int)
|
||||||
|
"""The ID of this refresh token."""
|
||||||
|
|
||||||
|
next_token_id = attr.ib(type=Optional[int])
|
||||||
|
"""The ID of the refresh token which replaced this one."""
|
||||||
|
|
||||||
|
has_next_refresh_token_been_refreshed = attr.ib(type=bool)
|
||||||
|
"""True if the next refresh token was used for another refresh."""
|
||||||
|
|
||||||
|
has_next_access_token_been_used = attr.ib(type=bool)
|
||||||
|
"""True if the next access token was already used at least once."""
|
||||||
|
|
||||||
|
|
||||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -441,7 +468,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||||||
access_tokens.id as token_id,
|
access_tokens.id as token_id,
|
||||||
access_tokens.device_id,
|
access_tokens.device_id,
|
||||||
access_tokens.valid_until_ms,
|
access_tokens.valid_until_ms,
|
||||||
access_tokens.user_id as token_owner
|
access_tokens.user_id as token_owner,
|
||||||
|
access_tokens.used as token_used
|
||||||
FROM users
|
FROM users
|
||||||
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
|
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
|
||||||
WHERE token = ?
|
WHERE token = ?
|
||||||
@ -449,8 +477,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||||||
|
|
||||||
txn.execute(sql, (token,))
|
txn.execute(sql, (token,))
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
rows = self.db_pool.cursor_to_dict(txn)
|
||||||
|
|
||||||
if rows:
|
if rows:
|
||||||
return TokenLookupResult(**rows[0])
|
row = rows[0]
|
||||||
|
|
||||||
|
# This field is nullable, ensure it comes out as a boolean
|
||||||
|
if row["token_used"] is None:
|
||||||
|
row["token_used"] = False
|
||||||
|
|
||||||
|
return TokenLookupResult(**row)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -1072,6 +1107,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||||||
desc="update_access_token_last_validated",
|
desc="update_access_token_last_validated",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
async def mark_access_token_as_used(self, token_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Mark the access token as used, which invalidates the refresh token used
|
||||||
|
to obtain it.
|
||||||
|
|
||||||
|
Because get_user_by_access_token is cached, this function might be
|
||||||
|
called multiple times for the same token, effectively doing unnecessary
|
||||||
|
SQL updates. Because updating the `used` field only goes one way (from
|
||||||
|
False to True) it is safe to cache this function as well to avoid this
|
||||||
|
issue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_id: The ID of the access token to update.
|
||||||
|
Raises:
|
||||||
|
StoreError if there was a problem updating this.
|
||||||
|
"""
|
||||||
|
await self.db_pool.simple_update_one(
|
||||||
|
"access_tokens",
|
||||||
|
{"id": token_id},
|
||||||
|
{"used": True},
|
||||||
|
desc="mark_access_token_as_used",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def lookup_refresh_token(
|
||||||
|
self, token: str
|
||||||
|
) -> Optional[RefreshTokenLookupResult]:
|
||||||
|
"""Lookup a refresh token with hints about its validity."""
|
||||||
|
|
||||||
|
def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
|
||||||
|
txn.execute(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
rt.id token_id,
|
||||||
|
rt.user_id,
|
||||||
|
rt.device_id,
|
||||||
|
rt.next_token_id,
|
||||||
|
(nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
|
||||||
|
at.used has_next_access_token_been_used
|
||||||
|
FROM refresh_tokens rt
|
||||||
|
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
|
||||||
|
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
|
||||||
|
WHERE rt.token = ?
|
||||||
|
""",
|
||||||
|
(token,),
|
||||||
|
)
|
||||||
|
row = txn.fetchone()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return RefreshTokenLookupResult(
|
||||||
|
token_id=row[0],
|
||||||
|
user_id=row[1],
|
||||||
|
device_id=row[2],
|
||||||
|
next_token_id=row[3],
|
||||||
|
has_next_refresh_token_been_refreshed=row[4],
|
||||||
|
# This column is nullable, ensure it's a boolean
|
||||||
|
has_next_access_token_been_used=(row[5] or False),
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"lookup_refresh_token", _lookup_refresh_token_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Set the successor of a refresh token, removing the existing successor
|
||||||
|
if any.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_id: ID of the refresh token to update.
|
||||||
|
next_token_id: ID of its successor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _replace_refresh_token_txn(txn) -> None:
|
||||||
|
# First check if there was an existing refresh token
|
||||||
|
old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
"refresh_tokens",
|
||||||
|
{"id": token_id},
|
||||||
|
"next_token_id",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db_pool.simple_update_one_txn(
|
||||||
|
txn,
|
||||||
|
"refresh_tokens",
|
||||||
|
{"id": token_id},
|
||||||
|
{"next_token_id": next_token_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete the old "next" token if it exists. This should cascade and
|
||||||
|
# delete the associated access_token
|
||||||
|
if old_next_token_id is not None:
|
||||||
|
self.db_pool.simple_delete_one_txn(
|
||||||
|
txn,
|
||||||
|
"refresh_tokens",
|
||||||
|
{"id": old_next_token_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.db_pool.runInteraction(
|
||||||
|
"replace_refresh_token", _replace_refresh_token_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -1263,6 +1403,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||||||
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|
||||||
|
|
||||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||||
|
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
||||||
|
|
||||||
async def add_access_token_to_user(
|
async def add_access_token_to_user(
|
||||||
self,
|
self,
|
||||||
@ -1271,14 +1412,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||||||
device_id: Optional[str],
|
device_id: Optional[str],
|
||||||
valid_until_ms: Optional[int],
|
valid_until_ms: Optional[int],
|
||||||
puppets_user_id: Optional[str] = None,
|
puppets_user_id: Optional[str] = None,
|
||||||
|
refresh_token_id: Optional[int] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Adds an access token for the given user.
|
"""Adds an access token for the given user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: The user ID.
|
user_id: The user ID.
|
||||||
token: The new access token to add.
|
token: The new access token to add.
|
||||||
device_id: ID of the device to associate with the access token
|
device_id: ID of the device to associate with the access token.
|
||||||
valid_until_ms: when the token is valid until. None for no expiry.
|
valid_until_ms: when the token is valid until. None for no expiry.
|
||||||
|
puppets_user_id
|
||||||
|
refresh_token_id: ID of the refresh token generated alongside this
|
||||||
|
access token.
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem adding this.
|
StoreError if there was a problem adding this.
|
||||||
Returns:
|
Returns:
|
||||||
@ -1297,12 +1442,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||||||
"valid_until_ms": valid_until_ms,
|
"valid_until_ms": valid_until_ms,
|
||||||
"puppets_user_id": puppets_user_id,
|
"puppets_user_id": puppets_user_id,
|
||||||
"last_validated": now,
|
"last_validated": now,
|
||||||
|
"refresh_token_id": refresh_token_id,
|
||||||
|
"used": False,
|
||||||
},
|
},
|
||||||
desc="add_access_token_to_user",
|
desc="add_access_token_to_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
return next_id
|
return next_id
|
||||||
|
|
||||||
|
async def add_refresh_token_to_user(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
token: str,
|
||||||
|
device_id: Optional[str],
|
||||||
|
) -> int:
|
||||||
|
"""Adds a refresh token for the given user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID.
|
||||||
|
token: The new access token to add.
|
||||||
|
device_id: ID of the device to associate with the refresh token.
|
||||||
|
Raises:
|
||||||
|
StoreError if there was a problem adding this.
|
||||||
|
Returns:
|
||||||
|
The token ID
|
||||||
|
"""
|
||||||
|
next_id = self._refresh_tokens_id_gen.get_next()
|
||||||
|
|
||||||
|
await self.db_pool.simple_insert(
|
||||||
|
"refresh_tokens",
|
||||||
|
{
|
||||||
|
"id": next_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
"token": token,
|
||||||
|
"next_token_id": None,
|
||||||
|
},
|
||||||
|
desc="add_refresh_token_to_user",
|
||||||
|
)
|
||||||
|
|
||||||
|
return next_id
|
||||||
|
|
||||||
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
|
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
|
||||||
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
||||||
txn, "access_tokens", {"token": token}, "device_id"
|
txn, "access_tokens", {"token": token}, "device_id"
|
||||||
@ -1545,7 +1725,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||||||
device_id: Optional[str] = None,
|
device_id: Optional[str] = None,
|
||||||
) -> List[Tuple[str, int, Optional[str]]]:
|
) -> List[Tuple[str, int, Optional[str]]]:
|
||||||
"""
|
"""
|
||||||
Invalidate access tokens belonging to a user
|
Invalidate access and refresh tokens belonging to a user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: ID of user the tokens belong to
|
user_id: ID of user the tokens belong to
|
||||||
@ -1565,7 +1745,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||||||
items = keyvalues.items()
|
items = keyvalues.items()
|
||||||
where_clause = " AND ".join(k + " = ?" for k, _ in items)
|
where_clause = " AND ".join(k + " = ?" for k, _ in items)
|
||||||
values = [v for _, v in items] # type: List[Union[str, int]]
|
values = [v for _, v in items] # type: List[Union[str, int]]
|
||||||
|
# Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
|
||||||
|
# is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
|
||||||
|
# clause and values before we handle that. This seems to be only used in the "set password" handler.
|
||||||
|
refresh_where_clause = where_clause
|
||||||
|
refresh_values = values.copy()
|
||||||
if except_token_id:
|
if except_token_id:
|
||||||
|
# TODO: support that for refresh tokens
|
||||||
where_clause += " AND id != ?"
|
where_clause += " AND id != ?"
|
||||||
values.append(except_token_id)
|
values.append(except_token_id)
|
||||||
|
|
||||||
@ -1583,6 +1769,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||||||
|
|
||||||
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
|
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
"DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
|
||||||
|
refresh_values,
|
||||||
|
)
|
||||||
|
|
||||||
return tokens_and_devices
|
return tokens_and_devices
|
||||||
|
|
||||||
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
|
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
|
||||||
@ -1599,6 +1790,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||||||
|
|
||||||
await self.db_pool.runInteraction("delete_access_token", f)
|
await self.db_pool.runInteraction("delete_access_token", f)
|
||||||
|
|
||||||
|
async def delete_refresh_token(self, refresh_token: str) -> None:
|
||||||
|
def f(txn):
|
||||||
|
self.db_pool.simple_delete_one_txn(
|
||||||
|
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.db_pool.runInteraction("delete_refresh_token", f)
|
||||||
|
|
||||||
async def add_user_pending_deactivation(self, user_id: str) -> None:
|
async def add_user_pending_deactivation(self, user_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Adds a user to the table of users who need to be parted from all the rooms they're
|
Adds a user to the table of users who need to be parted from all the rooms they're
|
||||||
|
34
synapse/storage/schema/main/delta/59/14refresh_tokens.sql
Normal file
34
synapse/storage/schema/main/delta/59/14refresh_tokens.sql
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- Holds MSC2918 refresh tokens
|
||||||
|
CREATE TABLE refresh_tokens (
|
||||||
|
id BIGINT PRIMARY KEY,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
token TEXT NOT NULL,
|
||||||
|
-- When consumed, a new refresh token is generated, which is tracked by
|
||||||
|
-- this foreign key
|
||||||
|
next_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE,
|
||||||
|
UNIQUE(token)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Add a reference to the refresh token generated alongside each access token
|
||||||
|
ALTER TABLE "access_tokens"
|
||||||
|
ADD COLUMN refresh_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE;
|
||||||
|
|
||||||
|
-- Add a flag whether the token was already used or not
|
||||||
|
ALTER TABLE "access_tokens"
|
||||||
|
ADD COLUMN used BOOLEAN;
|
@ -58,6 +58,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
user_id=self.test_user, token_id=5, device_id="device"
|
user_id=self.test_user, token_id=5, device_id="device"
|
||||||
)
|
)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(user_info)
|
self.store.get_user_by_access_token = simple_async_mock(user_info)
|
||||||
|
self.store.mark_access_token_as_used = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
|
@ -257,7 +257,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
|
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
|
||||||
|
|
||||||
# Create a new login for the user and dehydrated the device
|
# Create a new login for the user and dehydrated the device
|
||||||
device_id, access_token = self.get_success(
|
device_id, access_token, _expiration_time, _refresh_token = self.get_success(
|
||||||
self.registration.register_device(
|
self.registration.register_device(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
device_id=None,
|
device_id=None,
|
||||||
|
@ -20,7 +20,7 @@ import synapse.rest.admin
|
|||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
||||||
from synapse.rest.client.v1 import login
|
from synapse.rest.client.v1 import login
|
||||||
from synapse.rest.client.v2_alpha import auth, devices, register
|
from synapse.rest.client.v2_alpha import account, auth, devices, register
|
||||||
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
@ -498,3 +498,221 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||||||
self.delete_device(
|
self.delete_device(
|
||||||
self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
|
self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
auth.register_servlets,
|
||||||
|
account.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
|
register.register_servlets,
|
||||||
|
]
|
||||||
|
hijack_auth = False
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, hs):
|
||||||
|
self.user_pass = "pass"
|
||||||
|
self.user = self.register_user("test", self.user_pass)
|
||||||
|
|
||||||
|
def test_login_issue_refresh_token(self):
|
||||||
|
"""
|
||||||
|
A login response should include a refresh_token only if asked.
|
||||||
|
"""
|
||||||
|
# Test login
|
||||||
|
body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
|
||||||
|
|
||||||
|
login_without_refresh = self.make_request(
|
||||||
|
"POST", "/_matrix/client/r0/login", body
|
||||||
|
)
|
||||||
|
self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result)
|
||||||
|
self.assertNotIn("refresh_token", login_without_refresh.json_body)
|
||||||
|
|
||||||
|
login_with_refresh = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
|
||||||
|
body,
|
||||||
|
)
|
||||||
|
self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result)
|
||||||
|
self.assertIn("refresh_token", login_with_refresh.json_body)
|
||||||
|
self.assertIn("expires_in_ms", login_with_refresh.json_body)
|
||||||
|
|
||||||
|
def test_register_issue_refresh_token(self):
|
||||||
|
"""
|
||||||
|
A register response should include a refresh_token only if asked.
|
||||||
|
"""
|
||||||
|
register_without_refresh = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/r0/register",
|
||||||
|
{
|
||||||
|
"username": "test2",
|
||||||
|
"password": self.user_pass,
|
||||||
|
"auth": {"type": LoginType.DUMMY},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
register_without_refresh.code, 200, register_without_refresh.result
|
||||||
|
)
|
||||||
|
self.assertNotIn("refresh_token", register_without_refresh.json_body)
|
||||||
|
|
||||||
|
register_with_refresh = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true",
|
||||||
|
{
|
||||||
|
"username": "test3",
|
||||||
|
"password": self.user_pass,
|
||||||
|
"auth": {"type": LoginType.DUMMY},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result)
|
||||||
|
self.assertIn("refresh_token", register_with_refresh.json_body)
|
||||||
|
self.assertIn("expires_in_ms", register_with_refresh.json_body)
|
||||||
|
|
||||||
|
def test_token_refresh(self):
|
||||||
|
"""
|
||||||
|
A refresh token can be used to issue a new access token.
|
||||||
|
"""
|
||||||
|
body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
|
||||||
|
login_response = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
|
||||||
|
body,
|
||||||
|
)
|
||||||
|
self.assertEqual(login_response.code, 200, login_response.result)
|
||||||
|
|
||||||
|
refresh_response = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
|
||||||
|
{"refresh_token": login_response.json_body["refresh_token"]},
|
||||||
|
)
|
||||||
|
self.assertEqual(refresh_response.code, 200, refresh_response.result)
|
||||||
|
self.assertIn("access_token", refresh_response.json_body)
|
||||||
|
self.assertIn("refresh_token", refresh_response.json_body)
|
||||||
|
self.assertIn("expires_in_ms", refresh_response.json_body)
|
||||||
|
|
||||||
|
# The access and refresh tokens should be different from the original ones after refresh
|
||||||
|
self.assertNotEqual(
|
||||||
|
login_response.json_body["access_token"],
|
||||||
|
refresh_response.json_body["access_token"],
|
||||||
|
)
|
||||||
|
self.assertNotEqual(
|
||||||
|
login_response.json_body["refresh_token"],
|
||||||
|
refresh_response.json_body["refresh_token"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_config({"access_token_lifetime": "1m"})
|
||||||
|
def test_refresh_token_expiration(self):
|
||||||
|
"""
|
||||||
|
The access token should have some time as specified in the config.
|
||||||
|
"""
|
||||||
|
body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
|
||||||
|
login_response = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
|
||||||
|
body,
|
||||||
|
)
|
||||||
|
self.assertEqual(login_response.code, 200, login_response.result)
|
||||||
|
self.assertApproximates(
|
||||||
|
login_response.json_body["expires_in_ms"], 60 * 1000, 100
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_response = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
|
||||||
|
{"refresh_token": login_response.json_body["refresh_token"]},
|
||||||
|
)
|
||||||
|
self.assertEqual(refresh_response.code, 200, refresh_response.result)
|
||||||
|
self.assertApproximates(
|
||||||
|
refresh_response.json_body["expires_in_ms"], 60 * 1000, 100
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_refresh_token_invalidation(self):
|
||||||
|
"""Refresh tokens are invalidated after first use of the next token.
|
||||||
|
|
||||||
|
A refresh token is considered invalid if:
|
||||||
|
- it was already used at least once
|
||||||
|
- and either
|
||||||
|
- the next access token was used
|
||||||
|
- the next refresh token was used
|
||||||
|
|
||||||
|
The chain of tokens goes like this:
|
||||||
|
|
||||||
|
login -|-> first_refresh -> third_refresh (fails)
|
||||||
|
|-> second_refresh -> fifth_refresh
|
||||||
|
|-> fourth_refresh (fails)
|
||||||
|
"""
|
||||||
|
|
||||||
|
body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
|
||||||
|
login_response = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
|
||||||
|
body,
|
||||||
|
)
|
||||||
|
self.assertEqual(login_response.code, 200, login_response.result)
|
||||||
|
|
||||||
|
# This first refresh should work properly
|
||||||
|
first_refresh_response = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
|
||||||
|
{"refresh_token": login_response.json_body["refresh_token"]},
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
first_refresh_response.code, 200, first_refresh_response.result
|
||||||
|
)
|
||||||
|
|
||||||
|
# This one as well, since the token in the first one was never used
|
||||||
|
second_refresh_response = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
|
||||||
|
{"refresh_token": login_response.json_body["refresh_token"]},
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
second_refresh_response.code, 200, second_refresh_response.result
|
||||||
|
)
|
||||||
|
|
||||||
|
# This one should not, since the token from the first refresh is not valid anymore
|
||||||
|
third_refresh_response = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
|
||||||
|
{"refresh_token": first_refresh_response.json_body["refresh_token"]},
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
third_refresh_response.code, 401, third_refresh_response.result
|
||||||
|
)
|
||||||
|
|
||||||
|
# The associated access token should also be invalid
|
||||||
|
whoami_response = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/client/r0/account/whoami",
|
||||||
|
access_token=first_refresh_response.json_body["access_token"],
|
||||||
|
)
|
||||||
|
self.assertEqual(whoami_response.code, 401, whoami_response.result)
|
||||||
|
|
||||||
|
# But all other tokens should work (they will expire after some time)
|
||||||
|
for access_token in [
|
||||||
|
second_refresh_response.json_body["access_token"],
|
||||||
|
login_response.json_body["access_token"],
|
||||||
|
]:
|
||||||
|
whoami_response = self.make_request(
|
||||||
|
"GET", "/_matrix/client/r0/account/whoami", access_token=access_token
|
||||||
|
)
|
||||||
|
self.assertEqual(whoami_response.code, 200, whoami_response.result)
|
||||||
|
|
||||||
|
# Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail
|
||||||
|
fourth_refresh_response = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
|
||||||
|
{"refresh_token": login_response.json_body["refresh_token"]},
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
fourth_refresh_response.code, 403, fourth_refresh_response.result
|
||||||
|
)
|
||||||
|
|
||||||
|
# But refreshing from the last valid refresh token still works
|
||||||
|
fifth_refresh_response = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
|
||||||
|
{"refresh_token": second_refresh_response.json_body["refresh_token"]},
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
fifth_refresh_response.code, 200, fifth_refresh_response.result
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user