mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-05-08 06:54:56 -04:00
Persist user interactive authentication sessions (#7302)
By persisting the user interactive authentication sessions to the database, this fixes situations where a user hits different works throughout their auth session and also allows sessions to persist through restarts of Synapse.
This commit is contained in:
parent
9d8ecc9e6c
commit
627b0f5f27
14 changed files with 434 additions and 125 deletions
|
@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
|||
from synapse.http.server import finish_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import defer_to_thread
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.types import Requester, UserID
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
@ -69,15 +69,6 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
||||
|
||||
# This is not a cache per se, but a store of all current sessions that
|
||||
# expire after N hours
|
||||
self.sessions = ExpiringCache(
|
||||
cache_name="register_sessions",
|
||||
clock=hs.get_clock(),
|
||||
expiry_ms=self.SESSION_EXPIRE_MS,
|
||||
reset_expiry_on_get=True,
|
||||
)
|
||||
|
||||
account_handler = ModuleApi(hs, self)
|
||||
self.password_providers = [
|
||||
module(config=config, account_handler=account_handler)
|
||||
|
@ -119,6 +110,15 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
self._clock = self.hs.get_clock()
|
||||
|
||||
# Expire old UI auth sessions after a period of time.
|
||||
if hs.config.worker_app is None:
|
||||
self._clock.looping_call(
|
||||
run_as_background_process,
|
||||
5 * 60 * 1000,
|
||||
"expire_old_sessions",
|
||||
self._expire_old_sessions,
|
||||
)
|
||||
|
||||
# Load the SSO HTML templates.
|
||||
|
||||
# The following template is shown to the user during a client login via SSO,
|
||||
|
@ -301,16 +301,21 @@ class AuthHandler(BaseHandler):
|
|||
if "session" in authdict:
|
||||
sid = authdict["session"]
|
||||
|
||||
# Convert the URI and method to strings.
|
||||
uri = request.uri.decode("utf-8")
|
||||
method = request.uri.decode("utf-8")
|
||||
|
||||
# If there's no session ID, create a new session.
|
||||
if not sid:
|
||||
session = self._create_session(
|
||||
clientdict, (request.uri, request.method, clientdict), description
|
||||
session = await self.store.create_ui_auth_session(
|
||||
clientdict, uri, method, description
|
||||
)
|
||||
session_id = session["id"]
|
||||
|
||||
else:
|
||||
session = self._get_session_info(sid)
|
||||
session_id = sid
|
||||
try:
|
||||
session = await self.store.get_ui_auth_session(sid)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (sid,))
|
||||
|
||||
if not clientdict:
|
||||
# This was designed to allow the client to omit the parameters
|
||||
|
@ -322,15 +327,15 @@ class AuthHandler(BaseHandler):
|
|||
# on a homeserver.
|
||||
# Revisit: Assuming the REST APIs do sensible validation, the data
|
||||
# isn't arbitrary.
|
||||
clientdict = session["clientdict"]
|
||||
clientdict = session.clientdict
|
||||
|
||||
# Ensure that the queried operation does not vary between stages of
|
||||
# the UI authentication session. This is done by generating a stable
|
||||
# comparator based on the URI, method, and body (minus the auth dict)
|
||||
# and storing it during the initial query. Subsequent queries ensure
|
||||
# that this comparator has not changed.
|
||||
comparator = (request.uri, request.method, clientdict)
|
||||
if session["ui_auth"] != comparator:
|
||||
comparator = (uri, method, clientdict)
|
||||
if (session.uri, session.method, session.clientdict) != comparator:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Requested operation has changed during the UI authentication session.",
|
||||
|
@ -338,11 +343,9 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
if not authdict:
|
||||
raise InteractiveAuthIncompleteError(
|
||||
self._auth_dict_for_flows(flows, session_id)
|
||||
self._auth_dict_for_flows(flows, session.session_id)
|
||||
)
|
||||
|
||||
creds = session["creds"]
|
||||
|
||||
# check auth type currently being presented
|
||||
errordict = {} # type: Dict[str, Any]
|
||||
if "type" in authdict:
|
||||
|
@ -350,8 +353,9 @@ class AuthHandler(BaseHandler):
|
|||
try:
|
||||
result = await self._check_auth_dict(authdict, clientip)
|
||||
if result:
|
||||
creds[login_type] = result
|
||||
self._save_session(session)
|
||||
await self.store.mark_ui_auth_stage_complete(
|
||||
session.session_id, login_type, result
|
||||
)
|
||||
except LoginError as e:
|
||||
if login_type == LoginType.EMAIL_IDENTITY:
|
||||
# riot used to have a bug where it would request a new
|
||||
|
@ -367,6 +371,7 @@ class AuthHandler(BaseHandler):
|
|||
# so that the client can have another go.
|
||||
errordict = e.error_dict()
|
||||
|
||||
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
|
||||
for f in flows:
|
||||
if len(set(f) - set(creds)) == 0:
|
||||
# it's very useful to know what args are stored, but this can
|
||||
|
@ -380,9 +385,9 @@ class AuthHandler(BaseHandler):
|
|||
list(clientdict),
|
||||
)
|
||||
|
||||
return creds, clientdict, session_id
|
||||
return creds, clientdict, session.session_id
|
||||
|
||||
ret = self._auth_dict_for_flows(flows, session_id)
|
||||
ret = self._auth_dict_for_flows(flows, session.session_id)
|
||||
ret["completed"] = list(creds)
|
||||
ret.update(errordict)
|
||||
raise InteractiveAuthIncompleteError(ret)
|
||||
|
@ -399,13 +404,11 @@ class AuthHandler(BaseHandler):
|
|||
if "session" not in authdict:
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
|
||||
sess = self._get_session_info(authdict["session"])
|
||||
creds = sess["creds"]
|
||||
|
||||
result = await self.checkers[stagetype].check_auth(authdict, clientip)
|
||||
if result:
|
||||
creds[stagetype] = result
|
||||
self._save_session(sess)
|
||||
await self.store.mark_ui_auth_stage_complete(
|
||||
authdict["session"], stagetype, result
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -427,7 +430,7 @@ class AuthHandler(BaseHandler):
|
|||
sid = authdict["session"]
|
||||
return sid
|
||||
|
||||
def set_session_data(self, session_id: str, key: str, value: Any) -> None:
|
||||
async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
|
||||
"""
|
||||
Store a key-value pair into the sessions data associated with this
|
||||
request. This data is stored server-side and cannot be modified by
|
||||
|
@ -438,11 +441,12 @@ class AuthHandler(BaseHandler):
|
|||
key: The key to store the data under
|
||||
value: The data to store
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
sess["serverdict"][key] = value
|
||||
self._save_session(sess)
|
||||
try:
|
||||
await self.store.set_ui_auth_session_data(session_id, key, value)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
def get_session_data(
|
||||
async def get_session_data(
|
||||
self, session_id: str, key: str, default: Optional[Any] = None
|
||||
) -> Any:
|
||||
"""
|
||||
|
@ -453,8 +457,18 @@ class AuthHandler(BaseHandler):
|
|||
key: The key to store the data under
|
||||
default: Value to return if the key has not been set
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
return sess["serverdict"].get(key, default)
|
||||
try:
|
||||
return await self.store.get_ui_auth_session_data(session_id, key, default)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
async def _expire_old_sessions(self):
|
||||
"""
|
||||
Invalidate any user interactive authentication sessions that have expired.
|
||||
"""
|
||||
now = self._clock.time_msec()
|
||||
expiration_time = now - self.SESSION_EXPIRE_MS
|
||||
await self.store.delete_old_ui_auth_sessions(expiration_time)
|
||||
|
||||
async def _check_auth_dict(
|
||||
self, authdict: Dict[str, Any], clientip: str
|
||||
|
@ -534,67 +548,6 @@ class AuthHandler(BaseHandler):
|
|||
"params": params,
|
||||
}
|
||||
|
||||
def _create_session(
|
||||
self,
|
||||
clientdict: Dict[str, Any],
|
||||
ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
|
||||
description: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Creates a new user interactive authentication session.
|
||||
|
||||
The session can be used to track data across multiple requests, e.g. for
|
||||
interactive authentication.
|
||||
|
||||
Each session has the following keys:
|
||||
|
||||
id:
|
||||
A unique identifier for this session. Passed back to the client
|
||||
and returned for each stage.
|
||||
clientdict:
|
||||
The dictionary from the client root level, not the 'auth' key.
|
||||
ui_auth:
|
||||
A tuple which is checked at each stage of the authentication to
|
||||
ensure that the asked for operation has not changed.
|
||||
creds:
|
||||
A map, which maps each auth-type (str) to the relevant identity
|
||||
authenticated by that auth-type (mostly str, but for captcha, bool).
|
||||
serverdict:
|
||||
A map of data that is stored server-side and cannot be modified
|
||||
by the client.
|
||||
description:
|
||||
A string description of the operation that the current
|
||||
authentication is authorising.
|
||||
Returns:
|
||||
The newly created session.
|
||||
"""
|
||||
session_id = None
|
||||
while session_id is None or session_id in self.sessions:
|
||||
session_id = stringutils.random_string(24)
|
||||
|
||||
self.sessions[session_id] = {
|
||||
"id": session_id,
|
||||
"clientdict": clientdict,
|
||||
"ui_auth": ui_auth,
|
||||
"creds": {},
|
||||
"serverdict": {},
|
||||
"description": description,
|
||||
}
|
||||
|
||||
return self.sessions[session_id]
|
||||
|
||||
def _get_session_info(self, session_id: str) -> dict:
|
||||
"""
|
||||
Gets a session given a session ID.
|
||||
|
||||
The session can be used to track data across multiple requests, e.g. for
|
||||
interactive authentication.
|
||||
"""
|
||||
try:
|
||||
return self.sessions[session_id]
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
async def get_access_token_for_user_id(
|
||||
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
|
||||
):
|
||||
|
@ -994,13 +947,6 @@ class AuthHandler(BaseHandler):
|
|||
await self.store.user_delete_threepid(user_id, medium, address)
|
||||
return result
|
||||
|
||||
def _save_session(self, session: Dict[str, Any]) -> None:
|
||||
"""Update the last used time on the session to now and add it back to the session store."""
|
||||
# TODO: Persistent storage
|
||||
logger.debug("Saving session %s", session)
|
||||
session["last_used"] = self.hs.get_clock().time_msec()
|
||||
self.sessions[session["id"]] = session
|
||||
|
||||
async def hash(self, password: str) -> str:
|
||||
"""Computes a secure hash of password.
|
||||
|
||||
|
@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler):
|
|||
else:
|
||||
return False
|
||||
|
||||
def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
|
||||
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
|
||||
"""
|
||||
Get the HTML for the SSO redirect confirmation page.
|
||||
|
||||
|
@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler):
|
|||
Returns:
|
||||
The HTML to render.
|
||||
"""
|
||||
session = self._get_session_info(session_id)
|
||||
try:
|
||||
session = await self.store.get_ui_auth_session(session_id)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
return self._sso_auth_confirm_template.render(
|
||||
description=session["description"], redirect_url=redirect_url,
|
||||
description=session.description, redirect_url=redirect_url,
|
||||
)
|
||||
|
||||
def complete_sso_ui_auth(
|
||||
async def complete_sso_ui_auth(
|
||||
self, registered_user_id: str, session_id: str, request: SynapseRequest,
|
||||
):
|
||||
"""Having figured out a mxid for this user, complete the HTTP request
|
||||
|
@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler):
|
|||
process.
|
||||
"""
|
||||
# Mark the stage of the authentication as successful.
|
||||
sess = self._get_session_info(session_id)
|
||||
creds = sess["creds"]
|
||||
|
||||
# Save the user who authenticated with SSO, this will be used to ensure
|
||||
# that the account be modified is also the person who logged in.
|
||||
creds[LoginType.SSO] = registered_user_id
|
||||
self._save_session(sess)
|
||||
await self.store.mark_ui_auth_stage_complete(
|
||||
session_id, LoginType.SSO, registered_user_id
|
||||
)
|
||||
|
||||
# Render the HTML and return.
|
||||
html_bytes = self._sso_auth_success_template.encode("utf-8")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue