Support expiry of refresh tokens and expiry of the overall session when refresh tokens are in use. (#11425)

This commit is contained in:
reivilibre 2021-11-26 14:27:14 +00:00 committed by GitHub
parent e2c300e7e4
commit 1d8b80b334
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 338 additions and 54 deletions

View file

@ -14,7 +14,17 @@
import logging
import re
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
)
from typing_extensions import TypedDict
@ -458,6 +468,7 @@ class RefreshTokenServlet(RestServlet):
self.refreshable_access_token_lifetime = (
hs.config.registration.refreshable_access_token_lifetime
)
self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
refresh_submission = parse_json_object_from_request(request)
@ -467,22 +478,33 @@ class RefreshTokenServlet(RestServlet):
if not isinstance(token, str):
raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
valid_until_ms = (
self._clock.time_msec() + self.refreshable_access_token_lifetime
)
access_token, refresh_token = await self._auth_handler.refresh_token(
token, valid_until_ms
)
expires_in_ms = valid_until_ms - self._clock.time_msec()
return (
200,
{
"access_token": access_token,
"refresh_token": refresh_token,
"expires_in_ms": expires_in_ms,
},
now = self._clock.time_msec()
access_valid_until_ms = None
if self.refreshable_access_token_lifetime is not None:
access_valid_until_ms = now + self.refreshable_access_token_lifetime
refresh_valid_until_ms = None
if self.refresh_token_lifetime is not None:
refresh_valid_until_ms = now + self.refresh_token_lifetime
(
access_token,
refresh_token,
actual_access_token_expiry,
) = await self._auth_handler.refresh_token(
token, access_valid_until_ms, refresh_valid_until_ms
)
response: Dict[str, Union[str, int]] = {
"access_token": access_token,
"refresh_token": refresh_token,
}
# expires_in_ms is only present if the token expires
if actual_access_token_expiry is not None:
response["expires_in_ms"] = actual_access_token_expiry - now
return 200, response
class SsoRedirectServlet(RestServlet):
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [