#  Copyright 2021 The Matrix.org Foundation C.I.C.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from typing import TYPE_CHECKING

import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.util import json_encoder

if TYPE_CHECKING:
    from synapse.server import HomeServer


class SessionStore(SQLBaseStore):
    """
    A store for generic session data.

    Each type of session should provide a unique type (to separate sessions).

    Expired sessions are never returned by `get_session`. The rows themselves
    are removed by a periodic job, so they may linger in the table for a while
    after their expiry time has passed.
    """

    # How often the job that culls expired sessions runs, in milliseconds.
    _SESSION_CLEANUP_INTERVAL_MS = 30 * 60 * 1000

    # How many random session IDs to try before giving up on creating a session.
    _MAX_SESSION_ID_ATTEMPTS = 5

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Create a background job for culling expired sessions. This is only
        # scheduled on the process configured to run background tasks.
        if hs.config.worker.run_background_tasks:
            self._clock.looping_call(
                self._delete_expired_sessions, self._SESSION_CLEANUP_INTERVAL_MS
            )

    async def create_session(
        self, session_type: str, value: JsonDict, expiry_ms: int
    ) -> str:
        """
        Creates a new session of the given type and stores a value in it.

        Args:
            session_type: The type for this session.
            value: The value to store. It is JSON-encoded before being written.
            expiry_ms: How long, in milliseconds from now, until the session
                expires. Since `get_session` only returns sessions whose expiry
                time is strictly in the future, a value of 0 (or less) creates
                a session that is already considered expired.

        Returns:
            The newly created session ID.

        Raises:
            StoreError if a unique session ID cannot be generated.
        """
        # The encoded value does not change between attempts, so build it once.
        encoded_value = json_encoder.encode(value)

        # autogen a session ID and try to create it. We may clash, so just
        # try a few times till one goes through, giving up eventually.
        for _ in range(self._MAX_SESSION_ID_ATTEMPTS):
            session_id = stringutils.random_string(24)

            try:
                await self.db_pool.simple_insert(
                    table="sessions",
                    values={
                        "session_id": session_id,
                        "session_type": session_type,
                        "value": encoded_value,
                        "expiry_time_ms": self._clock.time_msec() + expiry_ms,
                    },
                    desc="create_session",
                )
            except self.db_pool.engine.module.IntegrityError:
                # The insert was rejected (presumably the generated ID clashed
                # with an existing row); try again with a fresh random ID.
                continue

            return session_id

        raise StoreError(500, "Couldn't generate a session ID.")

    async def get_session(self, session_type: str, session_id: str) -> JsonDict:
        """
        Retrieve data stored with create_session

        Args:
            session_type: The type for this session.
            session_id: The session ID returned from create_session.

        Returns:
            The value that was stored for the session, decoded from its
            database representation.

        Raises:
            StoreError if the session cannot be found. This includes sessions
            which have expired but have not yet been deleted.
        """

        def _get_session_txn(
            txn: LoggingTransaction, session_type: str, session_id: str, ts: int
        ) -> JsonDict:
            # This includes the expiry time since items are only periodically
            # deleted, not upon expiry.
            select_sql = """
            SELECT value FROM sessions WHERE
            session_type = ? AND session_id = ? AND expiry_time_ms > ?
            """
            txn.execute(select_sql, [session_type, session_id, ts])
            row = txn.fetchone()

            if not row:
                raise StoreError(404, "No session")

            return db_to_json(row[0])

        return await self.db_pool.runInteraction(
            "get_session",
            _get_session_txn,
            session_type,
            session_id,
            self._clock.time_msec(),
        )

    @wrap_as_background_process("delete_expired_sessions")
    async def _delete_expired_sessions(self) -> None:
        """Remove sessions with expiry dates that have passed."""

        def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
            # `<=` is the exact complement of the `>` check in get_session, so
            # only sessions that can no longer be fetched are deleted.
            sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
            txn.execute(sql, (ts,))

        await self.db_pool.runInteraction(
            "delete_expired_sessions",
            _delete_expired_sessions_txn,
            self._clock.time_msec(),
        )