# Copyright 2021 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
from weakref import WeakValueDictionary

from twisted.internet.interfaces import IReactorCore

from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.util import Clock
from synapse.util.stringutils import random_string

if TYPE_CHECKING:
    from synapse.server import HomeServer


logger = logging.getLogger(__name__)


# Period, in milliseconds, at which a held lock is kept alive by bumping the
# `last_renewed_ts` column of its row in the lock table.
_RENEWAL_INTERVAL_MS = 30_000

# Age, in milliseconds, after which a lock whose `last_renewed_ts` has not been
# bumped is treated as expired and may be taken over.
_LOCK_TIMEOUT_MS = 2 * 60_000


class LockStore(SQLBaseStore):
    """Provides a best effort distributed lock between worker instances.

    Locks are identified by a name and key. A lock is acquired by inserting into
    the `worker_locks` table if a) there is no existing row for the name/key,
    b) the existing row has a `last_renewed_ts` older than `_LOCK_TIMEOUT_MS`, or
    c) the existing row was written under this instance's name.

    When a lock is taken out the instance inserts a random `token`; the instance
    that holds that token holds the lock until it drops it (or it times out).

    The instance that holds the lock should regularly update the
    `last_renewed_ts` column with the current time. The `Lock` objects handed
    out by `try_acquire_lock` do this automatically while they are alive.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._reactor = hs.get_reactor()
        self._instance_name = hs.get_instance_id()

        # A map from `(lock_name, lock_key)` to the `Lock` object of any locks
        # that we think we currently hold. Values are held weakly, so an entry
        # disappears once its `Lock` is garbage collected.
        self._live_tokens: WeakValueDictionary[
            Tuple[str, str], Lock
        ] = WeakValueDictionary()

        # When we shut down we want to remove the locks. Technically this can
        # lead to a race, as we may drop the lock while we are still processing.
        # However, a) it should be a small window, b) the lock is best effort
        # anyway and c) we want to really avoid leaking locks when we restart.
        hs.get_reactor().addSystemEventTrigger(
            "before",
            "shutdown",
            self._on_shutdown,
        )

        # `(lock_name, lock_key)` pairs that some coroutine in this process is
        # currently in the middle of acquiring; concurrent attempts for the same
        # pair are refused rather than raced.
        self._acquiring_locks: Set[Tuple[str, str]] = set()

    @wrap_as_background_process("LockStore._on_shutdown")
    async def _on_shutdown(self) -> None:
        """Called when the server is shutting down.

        Releases every lock this process still holds, so that other instances
        do not have to wait for them to time out.
        """
        logger.info("Dropping held locks due to shutdown")

        # We need to take a copy of the locks dict as dropping the locks will
        # cause the dictionary to change.
        locks = dict(self._live_tokens)

        for (lock_name, lock_key), lock in locks.items():
            # Keep going if one release fails: aborting the loop would leak
            # every remaining lock until it times out.
            try:
                await lock.release()
            except Exception:
                logger.exception(
                    "Failed to drop lock for (%s, %s) during shutdown",
                    lock_name,
                    lock_key,
                )

        logger.info("Dropped locks due to shutdown")

    async def try_acquire_lock(self, lock_name: str, lock_key: str) -> Optional["Lock"]:
        """Try to acquire a lock for the given name/key.

        Returns an async context manager if the lock is successfully acquired,
        which *must* be used (otherwise the lock will leak). Returns None if
        the lock could not be taken, if this process already holds a still
        valid lock for the name/key, or if another coroutine in this process
        is concurrently trying to acquire the same lock.
        """
        key = (lock_name, lock_key)
        if key in self._acquiring_locks:
            return None

        self._acquiring_locks.add(key)
        try:
            return await self._try_acquire_lock(lock_name, lock_key)
        finally:
            self._acquiring_locks.discard(key)

    async def _try_acquire_lock(
        self, lock_name: str, lock_key: str
    ) -> Optional["Lock"]:
        """Try to acquire a lock for the given name/key. Will return an async
        context manager if the lock is successfully acquired, which *must* be
        used (otherwise the lock will leak).

        Callers should go through `try_acquire_lock`, which guards against
        concurrent acquisition attempts from within this process.
        """

        # Check if this process has taken out a lock and if it's still valid.
        lock = self._live_tokens.get((lock_name, lock_key))
        if lock is not None and await lock.is_still_valid():
            return None

        now = self._clock.time_msec()
        token = random_string(6)

        def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
            # We take out the lock if either a) there is no row for the lock
            # already, b) the existing row has timed out, or c) the row is
            # for this instance (which means the process got killed and
            # restarted)
            sql = """
               INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
               VALUES (?, ?, ?, ?, ?)
               ON CONFLICT (lock_name, lock_key)
               DO UPDATE
                   SET
                       token = EXCLUDED.token,
                       instance_name = EXCLUDED.instance_name,
                       last_renewed_ts = EXCLUDED.last_renewed_ts
                   WHERE
                       worker_locks.last_renewed_ts < ?
                       OR worker_locks.instance_name = EXCLUDED.instance_name
           """
            txn.execute(
                sql,
                (
                    lock_name,
                    lock_key,
                    self._instance_name,
                    token,
                    now,
                    now - _LOCK_TIMEOUT_MS,
                ),
            )

            # We only acquired the lock if we inserted or updated the table.
            return bool(txn.rowcount)

        did_lock = await self.db_pool.runInteraction(
            "try_acquire_lock",
            _try_acquire_lock_txn,
            # We can autocommit here as we're executing a single query, this
            # will avoid serialization errors.
            db_autocommit=True,
        )
        if not did_lock:
            return None

        lock = Lock(
            self._reactor,
            self._clock,
            self,
            lock_name=lock_name,
            lock_key=lock_key,
            token=token,
        )

        # This replaces any stale (no longer valid) `Lock` we had for this key.
        self._live_tokens[(lock_name, lock_key)] = lock

        return lock

    async def _is_lock_still_valid(
        self, lock_name: str, lock_key: str, token: str
    ) -> bool:
        """Checks whether this instance still holds the lock.

        True only if the `worker_locks` row for the name/key still carries
        `token` and was renewed within the last `_LOCK_TIMEOUT_MS`.
        """
        last_renewed_ts = await self.db_pool.simple_select_one_onecol(
            table="worker_locks",
            keyvalues={
                "lock_name": lock_name,
                "lock_key": lock_key,
                "token": token,
            },
            retcol="last_renewed_ts",
            allow_none=True,
            desc="is_lock_still_valid",
        )
        return (
            last_renewed_ts is not None
            and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts
        )

    async def _renew_lock(self, lock_name: str, lock_key: str, token: str) -> None:
        """Attempt to renew the lock if we still hold it.

        Sets `last_renewed_ts` to now on the row matching the name/key *and*
        `token`; if another holder has since replaced the token, no row
        matches and nothing is updated.
        """
        await self.db_pool.simple_update(
            table="worker_locks",
            keyvalues={
                "lock_name": lock_name,
                "lock_key": lock_key,
                "token": token,
            },
            updatevalues={"last_renewed_ts": self._clock.time_msec()},
            desc="renew_lock",
        )

    async def _drop_lock(self, lock_name: str, lock_key: str, token: str) -> None:
        """Attempt to drop the lock, if we still hold it.

        Deletes the `worker_locks` row only if it still carries `token`, and
        forgets the in-memory `Lock` only if it is the one holding that token.
        Releasing a stale `Lock` (e.g. one that timed out after which this
        process re-acquired the same name/key with a new token) must not evict
        the newer live `Lock`, otherwise this process would believe it no
        longer holds the lock and could take it out a second time.
        """
        await self.db_pool.simple_delete(
            table="worker_locks",
            keyvalues={
                "lock_name": lock_name,
                "lock_key": lock_key,
                "token": token,
            },
            desc="drop_lock",
        )

        live_lock = self._live_tokens.get((lock_name, lock_key))
        if live_lock is not None and live_lock._token == token:
            self._live_tokens.pop((lock_name, lock_key), None)


class Lock:
    """Async context manager around a lock obtained from
    `LockStore.try_acquire_lock`.

    While alive it renews the lock every `_RENEWAL_INTERVAL_MS`, and leaving the
    `async with` block releases it. A single `Lock` may only be entered while it
    has not been released.

    Long-running work (e.g. a loop) can call `is_still_valid` to double-check
    that the lock has not been lost in the meantime:

        lock = await self.store.try_acquire_lock(...)
        if not lock:
            return

        async with lock:
            for item in work:
                await process(item)

                if not await lock.is_still_valid():
                    break
    """

    def __init__(
        self,
        reactor: IReactorCore,
        clock: Clock,
        store: LockStore,
        lock_name: str,
        lock_key: str,
        token: str,
    ) -> None:
        self._reactor = reactor
        self._clock = clock
        self._store = store
        self._lock_name = lock_name
        self._lock_key = lock_key
        self._token = token

        # Schedule periodic renewal. Everything the callback needs is passed as
        # explicit arguments so the scheduled call never references `self`.
        renewal_args = (store, lock_name, lock_key, token)
        self._looping_call = clock.looping_call(
            self._renew, _RENEWAL_INTERVAL_MS, *renewal_args
        )

        self._dropped = False

    @staticmethod
    @wrap_as_background_process("Lock._renew")
    async def _renew(
        store: LockStore,
        lock_name: str,
        lock_key: str,
        token: str,
    ) -> None:
        """Bump the renewal timestamp of the lock identified by the arguments.

        Deliberately a static method: a bound method would make the renewal
        schedule hold a reference to `self`, preventing the `Lock` from being
        garbage collected if the context manager is dropped.
        """
        await store._renew_lock(lock_name, lock_key, token)

    async def is_still_valid(self) -> bool:
        """Return whether this process still holds the lock."""
        store = self._store
        return await store._is_lock_still_valid(
            self._lock_name, self._lock_key, self._token
        )

    async def __aenter__(self) -> None:
        # A released Lock is single-use and must not be entered again.
        if not self._dropped:
            return
        raise Exception("Cannot reuse a Lock object")

    async def __aexit__(
        self,
        _exctype: Optional[Type[BaseException]],
        _excinst: Optional[BaseException],
        _exctb: Optional[TracebackType],
    ) -> bool:
        await self.release()

        # Never swallow an exception raised inside the `async with` body.
        return False

    async def release(self) -> None:
        """Release the lock; calling it again afterwards does nothing.

        Invoked automatically on exit from the `async with` block.
        """
        if self._dropped:
            return

        renewer = self._looping_call
        if renewer.running:
            renewer.stop()

        store = self._store
        await store._drop_lock(self._lock_name, self._lock_key, self._token)
        self._dropped = True

    def __del__(self) -> None:
        if self._dropped:
            return

        # Being collected while still holding the lock is only expected during
        # shutdown; either way, at least stop renewing it.
        renewer = self._looping_call
        if renewer.running:
            renewer.stop()

        # Only report it while the reactor is still running.
        if not self._reactor.running:
            return
        logger.error(
            "Lock for (%s, %s) dropped without being released",
            self._lock_name,
            self._lock_key,
        )