# Copyright 2020 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, List, Optional, Tuple, Union, cast import attr from synapse.api.constants import LoginType from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict from synapse.util import json_encoder, stringutils @attr.s(slots=True, auto_attribs=True) class UIAuthSessionData: session_id: str # The dictionary from the client root level, not the 'auth' key. clientdict: JsonDict # The URI and method the session was intiatied with. These are checked at # each stage of the authentication to ensure that the asked for operation # has not changed. uri: str method: str # A string description of the operation that the current authentication is # authorising. description: str class UIAuthWorkerStore(SQLBaseStore): """ Manage user interactive authentication sessions. """ async def create_ui_auth_session( self, clientdict: JsonDict, uri: str, method: str, description: str, ) -> UIAuthSessionData: """ Creates a new user interactive authentication session. The session can be used to track the stages necessary to authenticate a user across multiple HTTP requests. Args: clientdict: The dictionary from the client root level, not the 'auth' key. uri: The URI this session was initiated with, this is checked at each stage of the authentication to ensure that the asked for operation has not changed. method: The method this session was initiated with, this is checked at each stage of the authentication to ensure that the asked for operation has not changed. description: A string description of the operation that the current authentication is authorising. Returns: The newly created session. Raises: StoreError if a unique session ID cannot be generated. """ # The clientdict gets stored as JSON. clientdict_json = json_encoder.encode(clientdict) # autogen a session ID and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. attempts = 0 while attempts < 5: session_id = stringutils.random_string(24) try: await self.db_pool.simple_insert( table="ui_auth_sessions", values={ "session_id": session_id, "clientdict": clientdict_json, "uri": uri, "method": method, "description": description, "serverdict": "{}", "creation_time": self.hs.get_clock().time_msec(), }, desc="create_ui_auth_session", ) return UIAuthSessionData( session_id, clientdict, uri, method, description ) except self.db_pool.engine.module.IntegrityError: attempts += 1 raise StoreError(500, "Couldn't generate a session ID.") async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData: """Retrieve a UI auth session. Args: session_id: The ID of the session. Returns: A dict containing the device information. Raises: StoreError if the session is not found. """ result = await self.db_pool.simple_select_one( table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("clientdict", "uri", "method", "description"), desc="get_ui_auth_session", ) result["clientdict"] = db_to_json(result["clientdict"]) return UIAuthSessionData(session_id, **result) async def mark_ui_auth_stage_complete( self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict], ) -> None: """ Mark a session stage as completed. Args: session_id: The ID of the corresponding session. stage_type: The completed stage type. result: The result of the stage verification. Raises: StoreError if the session cannot be found. """ # Add (or update) the results of the current stage to the database. # # Note that we need to allow for the same stage to complete multiple # times here so that registration is idempotent. try: await self.db_pool.simple_upsert( table="ui_auth_sessions_credentials", keyvalues={"session_id": session_id, "stage_type": stage_type}, values={"result": json_encoder.encode(result)}, desc="mark_ui_auth_stage_complete", ) except self.db_pool.engine.module.IntegrityError: raise StoreError(400, "Unknown session ID: %s" % (session_id,)) async def get_completed_ui_auth_stages( self, session_id: str ) -> Dict[str, Union[str, bool, JsonDict]]: """ Retrieve the completed stages of a UI authentication session. Args: session_id: The ID of the session. Returns: The completed stages mapped to the result of the verification of that auth-type. """ results = {} for row in await self.db_pool.simple_select_list( table="ui_auth_sessions_credentials", keyvalues={"session_id": session_id}, retcols=("stage_type", "result"), desc="get_completed_ui_auth_stages", ): results[row["stage_type"]] = db_to_json(row["result"]) return results async def set_ui_auth_clientdict( self, session_id: str, clientdict: JsonDict ) -> None: """ Store an updated clientdict for a given session ID. Args: session_id: The ID of this session as returned from check_auth clientdict: The dictionary from the client root level, not the 'auth' key. """ # The clientdict gets stored as JSON. clientdict_json = json_encoder.encode(clientdict) await self.db_pool.simple_update_one( table="ui_auth_sessions", keyvalues={"session_id": session_id}, updatevalues={"clientdict": clientdict_json}, desc="set_ui_auth_client_dict", ) async def set_ui_auth_session_data( self, session_id: str, key: str, value: Any ) -> None: """ Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by the client. Args: session_id: The ID of this session as returned from check_auth key: The key to store the data under value: The data to store Raises: StoreError if the session cannot be found. """ await self.db_pool.runInteraction( "set_ui_auth_session_data", self._set_ui_auth_session_data_txn, session_id, key, value, ) def _set_ui_auth_session_data_txn( self, txn: LoggingTransaction, session_id: str, key: str, value: Any ) -> None: # Get the current value. result = cast( Dict[str, Any], self.db_pool.simple_select_one_txn( txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("serverdict",), ), ) # Update it and add it back to the database. serverdict = db_to_json(result["serverdict"]) serverdict[key] = value self.db_pool.simple_update_one_txn( txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, updatevalues={"serverdict": json_encoder.encode(serverdict)}, ) async def get_ui_auth_session_data( self, session_id: str, key: str, default: Optional[Any] = None ) -> Any: """ Retrieve data stored with set_session_data Args: session_id: The ID of this session as returned from check_auth key: The key to store the data under default: Value to return if the key has not been set Raises: StoreError if the session cannot be found. """ result = await self.db_pool.simple_select_one( table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("serverdict",), desc="get_ui_auth_session_data", ) serverdict = db_to_json(result["serverdict"]) return serverdict.get(key, default) async def add_user_agent_ip_to_ui_auth_session( self, session_id: str, user_agent: str, ip: str, ) -> None: """Add the given user agent / IP to the tracking table""" await self.db_pool.simple_upsert( table="ui_auth_sessions_ips", keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, values={}, desc="add_user_agent_ip_to_ui_auth_session", ) async def get_user_agents_ips_to_ui_auth_session( self, session_id: str, ) -> List[Tuple[str, str]]: """Get the given user agents / IPs used during the ui auth process Returns: List of user_agent/ip pairs """ rows = await self.db_pool.simple_select_list( table="ui_auth_sessions_ips", keyvalues={"session_id": session_id}, retcols=("user_agent", "ip"), desc="get_user_agents_ips_to_ui_auth_session", ) return [(row["user_agent"], row["ip"]) for row in rows] async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. Args: expiration_time: The latest time that is still considered valid. This is an epoch time in milliseconds. """ await self.db_pool.runInteraction( "delete_old_ui_auth_sessions", self._delete_old_ui_auth_sessions_txn, expiration_time, ) def _delete_old_ui_auth_sessions_txn( self, txn: LoggingTransaction, expiration_time: int ) -> None: # Get the expired sessions. sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" txn.execute(sql, [expiration_time]) session_ids = [r[0] for r in txn.fetchall()] # Delete the corresponding IP/user agents. self.db_pool.simple_delete_many_txn( txn, table="ui_auth_sessions_ips", column="session_id", values=session_ids, keyvalues={}, ) # If a registration token was used, decrement the pending counter # before deleting the session. rows = cast( List[Tuple[str]], self.db_pool.simple_select_many_txn( txn, table="ui_auth_sessions_credentials", column="session_id", iterable=session_ids, keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN}, retcols=["result"], ), ) # Get the tokens used and how much pending needs to be decremented by. token_counts: Dict[str, int] = {} for r in rows: # If registration was successfully completed, the result of the # registration token stage for that session will be True. # If a token was used to authenticate, but registration was # never completed, the result will be the token used. token = db_to_json(r[0]) if isinstance(token, str): token_counts[token] = token_counts.get(token, 0) + 1 # Update the `pending` counters. if len(token_counts) > 0: token_rows = cast( List[Tuple[str, int]], self.db_pool.simple_select_many_txn( txn, table="registration_tokens", column="token", iterable=list(token_counts.keys()), keyvalues={}, retcols=["token", "pending"], ), ) for token, pending in token_rows: new_pending = pending - token_counts[token] self.db_pool.simple_update_one_txn( txn, table="registration_tokens", keyvalues={"token": token}, updatevalues={"pending": new_pending}, ) # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, table="ui_auth_sessions_credentials", column="session_id", values=session_ids, keyvalues={}, ) # Finally, delete the sessions. self.db_pool.simple_delete_many_txn( txn, table="ui_auth_sessions", column="session_id", values=session_ids, keyvalues={}, ) class UIAuthStore(UIAuthWorkerStore): pass