# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import enum
import threading
from typing import (
    Callable,
    Collection,
    Dict,
    Generic,
    MutableMapping,
    Optional,
    Set,
    Sized,
    Tuple,
    TypeVar,
    Union,
    cast,
)

from prometheus_client import Gauge

from twisted.internet import defer
from twisted.python.failure import Failure

from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry

cache_pending_metric = Gauge(
    "synapse_util_caches_cache_pending",
    "Number of lookups currently pending for this cache",
    ["name"],
)

T = TypeVar("T")
KT = TypeVar("KT")
VT = TypeVar("VT")


class _Sentinel(enum.Enum):
    # defining a sentinel in this way allows mypy to correctly handle the
    # type of a dictionary lookup.
    sentinel = object()


class DeferredCache(Generic[KT, VT]):
    """Wraps an LruCache, adding support for Deferred results.

    It expects that each entry added with set() will be a Deferred; likewise get()
    will return a Deferred.
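
    A minimal usage sketch. Here `fetch_from_db` is a hypothetical function
    returning a Deferred which, per the requirements of `set()`, does not
    follow the synapse logcontext rules::

        cache: DeferredCache[str, int] = DeferredCache("my_cache")

        def get_thing(key: str) -> "defer.Deferred[int]":
            try:
                # hit: either a completed result or an in-flight lookup
                return cache.get(key)
            except KeyError:
                # miss: start the fetch and share it with subsequent callers
                return cache.set(key, fetch_from_db(key))

    The Deferreds returned by both branches run their callbacks in the
    sentinel context, so callers should wrap them with
    `make_deferred_yieldable()` before awaiting them.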
    """

    __slots__ = (
        "cache",
        "thread",
        "_pending_deferred_cache",
    )

    def __init__(
        self,
        name: str,
        max_entries: int = 1000,
        tree: bool = False,
        iterable: bool = False,
        apply_cache_factor_from_config: bool = True,
        prune_unread_entries: bool = True,
    ):
        """
        Args:
            name: The name of the cache
            max_entries: Maximum number of entries that the cache will hold
            tree: Use a TreeCache instead of a dict as the underlying cache type
            iterable: If True, count each item in the cached object as an entry,
                rather than each cached object
            apply_cache_factor_from_config: Whether cache factors specified in the
                config file affect `max_entries`
            prune_unread_entries: If True, cache entries that haven't been read recently
                will be evicted from the cache in the background. Set to False to
                opt out of this behaviour.
        """
        cache_type = TreeCache if tree else dict

        # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
        self._pending_deferred_cache: Union[
            TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
        ] = cache_type()

        def metrics_cb() -> None:
            cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))

        # cache is used for completed results and maps to the result itself, rather than
        # a Deferred.
        self.cache: LruCache[KT, VT] = LruCache(
            max_size=max_entries,
            cache_name=name,
            cache_type=cache_type,
            size_callback=(
                (lambda d: len(cast(Sized, d)) or 1)
                # Argument 1 to "len" has incompatible type "VT"; expected "Sized"
                # We trust that `VT` is `Sized` when `iterable` is `True`
                if iterable
                else None
            ),
            metrics_collection_callback=metrics_cb,
            apply_cache_factor_from_config=apply_cache_factor_from_config,
            prune_unread_entries=prune_unread_entries,
        )

        self.thread: Optional[threading.Thread] = None

    @property
    def max_entries(self) -> int:
        return self.cache.max_size

    def check_thread(self) -> None:
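        """Check that this cache is only used from a single thread.

        The first call pins the cache to the calling thread; calls from any
        other thread then raise a ValueError.
        """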
        expected_thread = self.thread
        if expected_thread is None:
            self.thread = threading.current_thread()
        else:
            if expected_thread is not threading.current_thread():
                raise ValueError(
                    "Cache objects can only be accessed from the main thread"
                )

    def get(
        self,
        key: KT,
        callback: Optional[Callable[[], None]] = None,
        update_metrics: bool = True,
    ) -> defer.Deferred:
        """Looks the key up in the caches.

        For symmetry with set(), this method does *not* follow the synapse logcontext
        rules: the logcontext will not be cleared on return, and the Deferred will run
        its callbacks in the sentinel context. In other words: wrap the result with
        make_deferred_yieldable() before `await`ing it.

        Args:
            key: The key to look up in the cache
            callback: Gets called when the entry in the cache is invalidated
            update_metrics: whether to update the cache hit rate metrics

        Returns:
            A Deferred which completes with the result. Note that this may later fail
            if there is an ongoing set() operation which later completes with a failure.

        Raises:
            KeyError if the key is not found in the cache
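
        A usage sketch (`on_invalidate` is a hypothetical zero-argument
        callback)::

            try:
                value = await make_deferred_yieldable(
                    cache.get(key, callback=on_invalidate)
                )
            except KeyError:
                ...  # miss: fetch the value and `set()` it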
        """
        val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
        if val is not _Sentinel.sentinel:
            val.add_invalidation_callback(key, callback)
            if update_metrics:
                m = self.cache.metrics
                assert m  # we always have a name, so should always have metrics
                m.inc_hits()
            return val.deferred(key)

        callbacks = (callback,) if callback else ()

        val2 = self.cache.get(
            key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
        )
        if val2 is _Sentinel.sentinel:
            raise KeyError()
        else:
            return defer.succeed(val2)

    def get_bulk(
        self,
        keys: Collection[KT],
        callback: Optional[Callable[[], None]] = None,
    ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
        """Bulk lookup of items in the cache.

        Returns:
            A 3-tuple of:
                1. a dict of key/value of items already cached;
                2. a deferred that resolves to a dict of key/value of items
                   we're already fetching; and
                3. a collection of keys that don't appear in the previous two.
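
        A sketch of how a caller might combine the three parts; `fetch_from_db`
        is a hypothetical function returning a Deferred of a dict of values for
        the missing keys::

            cached, pending, missing = cache.get_bulk(keys)
            results = dict(cached)
            if missing:
                results.update(await make_deferred_yieldable(fetch_from_db(missing)))
            if pending:
                results.update(await make_deferred_yieldable(pending))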
        """

        # The cached results
        cached = {}

        # List of pending deferreds
        pending = []

        # Dict that gets filled out when the pending deferreds complete
        pending_results = {}

        # List of keys that aren't in either cache
        missing = []

        callbacks = (callback,) if callback else ()

        for key in keys:
            # Check if it's in the main cache.
            immediate_value = self.cache.get(
                key,
                _Sentinel.sentinel,
                callbacks=callbacks,
            )
            if immediate_value is not _Sentinel.sentinel:
                cached[key] = immediate_value
                continue

            # Check if it's in the pending cache
            pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
            if pending_value is not _Sentinel.sentinel:
                pending_value.add_invalidation_callback(key, callback)

                def completed_cb(value: VT, key: KT) -> VT:
                    pending_results[key] = value
                    return value

                # Add a callback to fill out `pending_results` when that completes
                d = pending_value.deferred(key).addCallback(completed_cb, key)
                pending.append(d)
                continue

            # Not in either cache
            missing.append(key)

        # If we've got pending deferreds, squash them into a single one that
        # returns `pending_results`.
        pending_deferred = None
        if pending:
            pending_deferred = defer.gatherResults(
                pending, consumeErrors=True
            ).addCallback(lambda _: pending_results)

        return (cached, pending_deferred, missing)

    def get_immediate(
        self, key: KT, default: T, update_metrics: bool = True
    ) -> Union[VT, T]:
        """If we have a *completed* cached value, return it."""
        return self.cache.get(key, default, update_metrics=update_metrics)

    def set(
        self,
        key: KT,
        value: "defer.Deferred[VT]",
        callback: Optional[Callable[[], None]] = None,
    ) -> defer.Deferred:
        """Adds a new entry to the cache (or updates an existing one).

        The given `value` *must* be a Deferred.

        First any existing entry for the same key is invalidated. Then a new entry
        is added to the cache for the given key.

        Until the `value` completes, calls to `get()` for the key will also result in an
        incomplete Deferred, which will ultimately complete with the same result as
        `value`.

        If `value` completes successfully, subsequent calls to `get()` will then return
        a completed deferred with the same result. If it *fails*, the cache is
        invalidated and subsequent calls to `get()` will raise a KeyError.

        If another call to `set()` happens before `value` completes, then (a) any
        invalidation callbacks registered in the interim will be called, (b) any
        `get()`s in the interim will continue to complete with the result from the
        *original* `value`, (c) any future calls to `get()` will complete with the
        result from the *new* `value`.

        It is expected that `value` does *not* follow the synapse logcontext rules - ie,
        if it is incomplete, it runs its callbacks in the sentinel context.

        Args:
            key: Key to be set
            value: a deferred which will complete with a result to add to the cache
            callback: An optional callback to be called when the entry is invalidated
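
        Returns:
            A new Deferred which completes (or fails) with the same result as
            `value` once the cache has been updated. Its callbacks run before
            those of any subsequent observers of `value`.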
        """
        self.check_thread()

        self._pending_deferred_cache.pop(key, None)

        # XXX: why don't we invalidate the entry in `self.cache` yet?

        # Add an entry to the _pending_deferred_cache for now, and add
        # callbacks to add it to `self.cache` properly once `value` completes.
        entry = CacheEntrySingle[KT, VT](value)
        entry.add_invalidation_callback(key, callback)
        self._pending_deferred_cache[key] = entry
        deferred = entry.deferred(key).addCallbacks(
            self._completed_callback,
            self._error_callback,
            callbackArgs=(entry, key),
            errbackArgs=(entry, key),
        )

        # we return a new Deferred which will be called before any subsequent observers.
        return deferred

    def start_bulk_input(
        self,
        keys: Collection[KT],
        callback: Optional[Callable[[], None]] = None,
    ) -> "CacheMultipleEntries[KT, VT]":
        """Bulk set API for use when fetching multiple keys at once from the DB.

        Called *before* starting the fetch from the DB, and the caller *must*
        call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
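
        A sketch of the expected calling pattern, where `fetch_many_from_db` is
        a hypothetical bulk fetch returning a dict of results::

            entry = cache.start_bulk_input(keys)
            try:
                results = await fetch_many_from_db(keys)
                entry.complete_bulk(cache, results)
            except Exception:
                entry.error_bulk(cache, keys, Failure())
                raise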
        """

        entry = CacheMultipleEntries[KT, VT]()
        entry.add_global_invalidation_callback(callback)

        for key in keys:
            self._pending_deferred_cache[key] = entry

        return entry

    def _completed_callback(
        self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
    ) -> VT:
        """Called when a deferred is completed."""
        # We check if the current entry matches the entry associated with the
        # deferred. If they don't match then it got invalidated.
        current_entry = self._pending_deferred_cache.pop(key, None)
        if current_entry is not entry:
            if current_entry:
                self._pending_deferred_cache[key] = current_entry
            return value

        self.cache.set(key, value, entry.get_invalidation_callbacks(key))

        return value

    def _error_callback(
        self,
        failure: Failure,
        entry: "CacheEntry[KT, VT]",
        key: KT,
    ) -> Failure:
        """Called when a deferred errors."""

        # We check if the current entry matches the entry associated with the
        # deferred. If they don't match then it got invalidated.
        current_entry = self._pending_deferred_cache.pop(key, None)
        if current_entry is not entry:
            if current_entry:
                self._pending_deferred_cache[key] = current_entry
            return failure

        for cb in entry.get_invalidation_callbacks(key):
            cb()

        return failure

    def prefill(
        self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
    ) -> None:
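        """Add a *completed* value to the cache.

        Any pending lookup for `key` is dropped, so that its result will not
        later overwrite the prefilled value.
        """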
        callbacks = (callback,) if callback else ()
        self.cache.set(key, value, callbacks=callbacks)
        self._pending_deferred_cache.pop(key, None)

    def invalidate(self, key: KT) -> None:
        """Delete a key, or tree of entries

        If the cache is backed by a regular dict, then "key" must be of
        the right type for this cache

        If the cache is backed by a TreeCache, then "key" must be a tuple, but
        may be of lower cardinality than the TreeCache - in which case the whole
        subtree is deleted.
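
        For example, with a TreeCache keyed on `(user_id, device_id)` tuples,
        `invalidate(("@user:example.com",))` would remove the entries for all
        of that user's devices.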
        """
        self.check_thread()
        self.cache.del_multi(key)

        # if we have a pending lookup for this key, remove it from the
        # _pending_deferred_cache, which will (a) stop it being returned for
        # future queries and (b) stop it being persisted as a proper entry
        # in self.cache.
        entry = self._pending_deferred_cache.pop(key, None)
        if entry:
            # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
            # case of a TreeCache, a dict of keys to cache entries. Either way calling
            # iterate_tree_cache_entry on it will do the right thing.
            for iter_entry in iterate_tree_cache_entry(entry):
                for cb in iter_entry.get_invalidation_callbacks(key):
                    cb()

    def invalidate_all(self) -> None:
        self.check_thread()
        self.cache.clear()
        for key, entry in self._pending_deferred_cache.items():
            for cb in entry.get_invalidation_callbacks(key):
                cb()

        self._pending_deferred_cache.clear()


class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
    """Abstract class for entries in `DeferredCache[KT, VT]`"""

    @abc.abstractmethod
    def deferred(self, key: KT) -> "defer.Deferred[VT]":
        """Get a deferred that a caller can wait on to get the value at the
        given key"""
        ...

    @abc.abstractmethod
    def add_invalidation_callback(
        self, key: KT, callback: Optional[Callable[[], None]]
    ) -> None:
        """Add an invalidation callback"""
        ...

    @abc.abstractmethod
    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
        """Get all invalidation callbacks"""
        ...


class CacheEntrySingle(CacheEntry[KT, VT]):
    """An implementation of `CacheEntry` wrapping a deferred that results in a
    single cache entry.
    """

    __slots__ = ["_deferred", "_callbacks"]

    def __init__(self, deferred: "defer.Deferred[VT]") -> None:
        self._deferred = ObservableDeferred(deferred, consumeErrors=True)
        self._callbacks: Set[Callable[[], None]] = set()

    def deferred(self, key: KT) -> "defer.Deferred[VT]":
        return self._deferred.observe()

    def add_invalidation_callback(
        self, key: KT, callback: Optional[Callable[[], None]]
    ) -> None:
        if callback is None:
            return

        self._callbacks.add(callback)

    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
        return self._callbacks


class CacheMultipleEntries(CacheEntry[KT, VT]):
    """Cache entry that is used for bulk lookups and insertions."""

    __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]

    def __init__(self) -> None:
        self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
        self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
        self._global_callbacks: Set[Callable[[], None]] = set()

    def deferred(self, key: KT) -> "defer.Deferred[VT]":
        if not self._deferred:
            self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
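        # NB: `dict.get` means that observers for a key absent from the bulk
        # result complete with `None` rather than failing.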
        return self._deferred.observe().addCallback(lambda res: res.get(key))

    def add_invalidation_callback(
        self, key: KT, callback: Optional[Callable[[], None]]
    ) -> None:
        if callback is None:
            return

        self._callbacks.setdefault(key, set()).add(callback)

    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
        return self._callbacks.get(key, set()) | self._global_callbacks

    def add_global_invalidation_callback(
        self, callback: Optional[Callable[[], None]]
    ) -> None:
        """Add a callback for when any keys get invalidated."""
        if callback is None:
            return

        self._global_callbacks.add(callback)

    def complete_bulk(
        self,
        cache: DeferredCache[KT, VT],
        result: Dict[KT, VT],
    ) -> None:
        """Called when there is a result"""
        for key, value in result.items():
            cache._completed_callback(value, self, key)

        if self._deferred:
            self._deferred.callback(result)

    def error_bulk(
        self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
    ) -> None:
        """Called when bulk lookup failed."""
        for key in keys:
            cache._error_callback(failure, self, key)

        if self._deferred:
            self._deferred.errback(failure)