MSC2918 Refresh tokens implementation (#9450)

This implements refresh tokens, as defined by MSC2918

This MSC has been implemented client side in Hydrogen Web: vector-im/hydrogen-web#235

The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one.

Signed-off-by: Quentin Gliech <quentingliech@gmail.com>
This commit is contained in:
Quentin Gliech 2021-06-24 15:33:20 +02:00 committed by GitHub
parent 763dba77ef
commit bd4919fb72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 892 additions and 61 deletions

View file

@ -14,7 +14,9 @@
import logging
import re
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
from typing_extensions import TypedDict
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@ -25,6 +27,8 @@ from synapse.http import get_request_uri
from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_boolean,
parse_bytes_from_args,
parse_json_object_from_request,
parse_string,
@ -40,6 +44,21 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
LoginResponse = TypedDict(
"LoginResponse",
{
"user_id": str,
"access_token": str,
"home_server": str,
"expires_in_ms": Optional[int],
"refresh_token": Optional[str],
"device_id": str,
"well_known": Optional[Dict[str, Any]],
},
total=False,
)
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet):
JWT_TYPE = "org.matrix.login.jwt"
JWT_TYPE_DEPRECATED = "m.login.jwt"
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet):
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._msc2918_enabled = hs.config.access_token_lifetime is not None
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()
@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet):
async def on_POST(self, request: SynapseRequest):
login_submission = parse_json_object_from_request(request)
if self._msc2918_enabled:
# Check if this login should also issue a refresh token, as per
# MSC2918
should_issue_refresh_token = parse_boolean(
request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
)
else:
should_issue_refresh_token = False
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
appservice = self.auth.get_appservice_by_req(request)
@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet):
None, request.getClientIP()
)
result = await self._do_appservice_login(login_submission, appservice)
result = await self._do_appservice_login(
login_submission,
appservice,
should_issue_refresh_token=should_issue_refresh_token,
)
elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
result = await self._do_jwt_login(login_submission)
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
result = await self._do_token_login(login_submission)
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
)
else:
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
result = await self._do_other_login(login_submission)
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet):
return 200, result
async def _do_appservice_login(
self, login_submission: JsonDict, appservice: ApplicationService
self,
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
):
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
return await self._complete_login(
qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
qualified_user_id,
login_submission,
ratelimit=appservice.is_rate_limited(),
should_issue_refresh_token=should_issue_refresh_token,
)
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
async def _do_other_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
) -> LoginResponse:
"""Handle non-token/saml/jwt logins
Args:
login_submission:
should_issue_refresh_token: True if this login should issue
a refresh token alongside the access token.
Returns:
HTTP response
@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet):
login_submission, ratelimit=True
)
result = await self._complete_login(
canonical_user_id, login_submission, callback
canonical_user_id,
login_submission,
callback,
should_issue_refresh_token=should_issue_refresh_token,
)
return result
@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet):
self,
user_id: str,
login_submission: JsonDict,
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
ratelimit: bool = True,
auth_provider_id: Optional[str] = None,
) -> Dict[str, str]:
should_issue_refresh_token: bool = False,
) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
all successful logins.
@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet):
ratelimit: Whether to ratelimit the login request.
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
should_issue_refresh_token: True if this login should issue
a refresh token alongside the access token.
Returns:
result: Dictionary of account information after successful login.
@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
(
device_id,
access_token,
valid_until_ms,
refresh_token,
) = await self.registration_handler.register_device(
user_id,
device_id,
initial_display_name,
auth_provider_id=auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token,
)
result = {
"user_id": user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
result = LoginResponse(
user_id=user_id,
access_token=access_token,
home_server=self.hs.hostname,
device_id=device_id,
)
if valid_until_ms is not None:
expires_in_ms = valid_until_ms - self.clock.time_msec()
result["expires_in_ms"] = expires_in_ms
if refresh_token is not None:
result["refresh_token"] = refresh_token
if callback is not None:
await callback(result)
return result
async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
async def _do_token_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
) -> LoginResponse:
"""
Handle the final stage of SSO login.
Args:
login_submission: The JSON request body.
login_submission: The JSON request body.
should_issue_refresh_token: True if this login should issue
a refresh token alongside the access token.
Returns:
The body of the JSON response.
@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet):
login_submission,
self.auth_handler._sso_login_callback,
auth_provider_id=res.auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token,
)
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
async def _do_jwt_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
) -> LoginResponse:
token = login_submission.get("token", None)
if token is None:
raise LoginError(
@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet):
user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login(
user_id, login_submission, create_non_existent_users=True
user_id,
login_submission,
create_non_existent_users=True,
should_issue_refresh_token=should_issue_refresh_token,
)
return result
@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp(
return e
class RefreshTokenServlet(RestServlet):
PATTERNS = client_patterns(
"/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
)
def __init__(self, hs: "HomeServer"):
self._auth_handler = hs.get_auth_handler()
self._clock = hs.get_clock()
self.access_token_lifetime = hs.config.access_token_lifetime
async def on_POST(
self,
request: SynapseRequest,
):
refresh_submission = parse_json_object_from_request(request)
assert_params_in_dict(refresh_submission, ["refresh_token"])
token = refresh_submission["refresh_token"]
if not isinstance(token, str):
raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
access_token, refresh_token = await self._auth_handler.refresh_token(
token, valid_until_ms
)
expires_in_ms = valid_until_ms - self._clock.time_msec()
return (
200,
{
"access_token": access_token,
"refresh_token": refresh_token,
"expires_in_ms": expires_in_ms,
},
)
class SsoRedirectServlet(RestServlet):
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
re.compile(
@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet):
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasTicketServlet(hs).register(http_server)