mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-06-19 09:04:06 -04:00
MSC2918 Refresh tokens implementation (#9450)
This implements refresh tokens, as defined by MSC2918 This MSC has been implemented client side in Hydrogen Web: vector-im/hydrogen-web#235 The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one. Signed-off-by: Quentin Gliech <quentingliech@gmail.com>
This commit is contained in:
parent
763dba77ef
commit
bd4919fb72
15 changed files with 892 additions and 61 deletions
|
@ -14,7 +14,9 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
|
@ -25,6 +27,8 @@ from synapse.http import get_request_uri
|
|||
from synapse.http.server import HttpServer, finish_request
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_boolean,
|
||||
parse_bytes_from_args,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
|
@ -40,6 +44,21 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
LoginResponse = TypedDict(
|
||||
"LoginResponse",
|
||||
{
|
||||
"user_id": str,
|
||||
"access_token": str,
|
||||
"home_server": str,
|
||||
"expires_in_ms": Optional[int],
|
||||
"refresh_token": Optional[str],
|
||||
"device_id": str,
|
||||
"well_known": Optional[Dict[str, Any]],
|
||||
},
|
||||
total=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/login$", v1=True)
|
||||
CAS_TYPE = "m.login.cas"
|
||||
|
@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet):
|
|||
JWT_TYPE = "org.matrix.login.jwt"
|
||||
JWT_TYPE_DEPRECATED = "m.login.jwt"
|
||||
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
|
||||
REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
|
@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet):
|
|||
self.cas_enabled = hs.config.cas_enabled
|
||||
self.oidc_enabled = hs.config.oidc_enabled
|
||||
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
|
||||
self._msc2918_enabled = hs.config.access_token_lifetime is not None
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self.auth_handler = self.hs.get_auth_handler()
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
|
@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet):
|
|||
async def on_POST(self, request: SynapseRequest):
|
||||
login_submission = parse_json_object_from_request(request)
|
||||
|
||||
if self._msc2918_enabled:
|
||||
# Check if this login should also issue a refresh token, as per
|
||||
# MSC2918
|
||||
should_issue_refresh_token = parse_boolean(
|
||||
request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
|
||||
)
|
||||
else:
|
||||
should_issue_refresh_token = False
|
||||
|
||||
try:
|
||||
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
||||
appservice = self.auth.get_appservice_by_req(request)
|
||||
|
@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet):
|
|||
None, request.getClientIP()
|
||||
)
|
||||
|
||||
result = await self._do_appservice_login(login_submission, appservice)
|
||||
result = await self._do_appservice_login(
|
||||
login_submission,
|
||||
appservice,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
elif self.jwt_enabled and (
|
||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
|
||||
):
|
||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||
result = await self._do_jwt_login(login_submission)
|
||||
result = await self._do_jwt_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||
result = await self._do_token_login(login_submission)
|
||||
result = await self._do_token_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
else:
|
||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||
result = await self._do_other_login(login_submission)
|
||||
result = await self._do_other_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Missing JSON keys.")
|
||||
|
||||
|
@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet):
|
|||
return 200, result
|
||||
|
||||
async def _do_appservice_login(
|
||||
self, login_submission: JsonDict, appservice: ApplicationService
|
||||
self,
|
||||
login_submission: JsonDict,
|
||||
appservice: ApplicationService,
|
||||
should_issue_refresh_token: bool = False,
|
||||
):
|
||||
identifier = login_submission.get("identifier")
|
||||
logger.info("Got appservice login request with identifier: %r", identifier)
|
||||
|
@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet):
|
|||
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
|
||||
|
||||
return await self._complete_login(
|
||||
qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
|
||||
qualified_user_id,
|
||||
login_submission,
|
||||
ratelimit=appservice.is_rate_limited(),
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
|
||||
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||
async def _do_other_login(
|
||||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
||||
) -> LoginResponse:
|
||||
"""Handle non-token/saml/jwt logins
|
||||
|
||||
Args:
|
||||
login_submission:
|
||||
should_issue_refresh_token: True if this login should issue
|
||||
a refresh token alongside the access token.
|
||||
|
||||
Returns:
|
||||
HTTP response
|
||||
|
@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet):
|
|||
login_submission, ratelimit=True
|
||||
)
|
||||
result = await self._complete_login(
|
||||
canonical_user_id, login_submission, callback
|
||||
canonical_user_id,
|
||||
login_submission,
|
||||
callback,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
return result
|
||||
|
||||
|
@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet):
|
|||
self,
|
||||
user_id: str,
|
||||
login_submission: JsonDict,
|
||||
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
|
||||
callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
|
||||
create_non_existent_users: bool = False,
|
||||
ratelimit: bool = True,
|
||||
auth_provider_id: Optional[str] = None,
|
||||
) -> Dict[str, str]:
|
||||
should_issue_refresh_token: bool = False,
|
||||
) -> LoginResponse:
|
||||
"""Called when we've successfully authed the user and now need to
|
||||
actually login them in (e.g. create devices). This gets called on
|
||||
all successful logins.
|
||||
|
@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet):
|
|||
ratelimit: Whether to ratelimit the login request.
|
||||
auth_provider_id: The SSO IdP the user used, if any (just used for the
|
||||
prometheus metrics).
|
||||
should_issue_refresh_token: True if this login should issue
|
||||
a refresh token alongside the access token.
|
||||
|
||||
Returns:
|
||||
result: Dictionary of account information after successful login.
|
||||
|
@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = await self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
|
||||
(
|
||||
device_id,
|
||||
access_token,
|
||||
valid_until_ms,
|
||||
refresh_token,
|
||||
) = await self.registration_handler.register_device(
|
||||
user_id,
|
||||
device_id,
|
||||
initial_display_name,
|
||||
auth_provider_id=auth_provider_id,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
|
||||
result = {
|
||||
"user_id": user_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
result = LoginResponse(
|
||||
user_id=user_id,
|
||||
access_token=access_token,
|
||||
home_server=self.hs.hostname,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
if valid_until_ms is not None:
|
||||
expires_in_ms = valid_until_ms - self.clock.time_msec()
|
||||
result["expires_in_ms"] = expires_in_ms
|
||||
|
||||
if refresh_token is not None:
|
||||
result["refresh_token"] = refresh_token
|
||||
|
||||
if callback is not None:
|
||||
await callback(result)
|
||||
|
||||
return result
|
||||
|
||||
async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||
async def _do_token_login(
|
||||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
||||
) -> LoginResponse:
|
||||
"""
|
||||
Handle the final stage of SSO login.
|
||||
|
||||
Args:
|
||||
login_submission: The JSON request body.
|
||||
login_submission: The JSON request body.
|
||||
should_issue_refresh_token: True if this login should issue
|
||||
a refresh token alongside the access token.
|
||||
|
||||
Returns:
|
||||
The body of the JSON response.
|
||||
|
@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet):
|
|||
login_submission,
|
||||
self.auth_handler._sso_login_callback,
|
||||
auth_provider_id=res.auth_provider_id,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
|
||||
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||
async def _do_jwt_login(
|
||||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
||||
) -> LoginResponse:
|
||||
token = login_submission.get("token", None)
|
||||
if token is None:
|
||||
raise LoginError(
|
||||
|
@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
user_id = UserID(user, self.hs.hostname).to_string()
|
||||
result = await self._complete_login(
|
||||
user_id, login_submission, create_non_existent_users=True
|
||||
user_id,
|
||||
login_submission,
|
||||
create_non_existent_users=True,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
return result
|
||||
|
||||
|
@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp(
|
|||
return e
|
||||
|
||||
|
||||
class RefreshTokenServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
"/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._clock = hs.get_clock()
|
||||
self.access_token_lifetime = hs.config.access_token_lifetime
|
||||
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
):
|
||||
refresh_submission = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_dict(refresh_submission, ["refresh_token"])
|
||||
token = refresh_submission["refresh_token"]
|
||||
if not isinstance(token, str):
|
||||
raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
|
||||
|
||||
valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
|
||||
access_token, refresh_token = await self._auth_handler.refresh_token(
|
||||
token, valid_until_ms
|
||||
)
|
||||
expires_in_ms = valid_until_ms - self._clock.time_msec()
|
||||
return (
|
||||
200,
|
||||
{
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"expires_in_ms": expires_in_ms,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SsoRedirectServlet(RestServlet):
|
||||
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
|
||||
re.compile(
|
||||
|
@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet):
|
|||
|
||||
def register_servlets(hs, http_server):
|
||||
LoginRestServlet(hs).register(http_server)
|
||||
if hs.config.access_token_lifetime is not None:
|
||||
RefreshTokenServlet(hs).register(http_server)
|
||||
SsoRedirectServlet(hs).register(http_server)
|
||||
if hs.config.cas_enabled:
|
||||
CasTicketServlet(hs).register(http_server)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue