from typing import Dict

from six import itervalues

# Unique marker object used as a lookup default so that "nothing stored here"
# can be told apart from any real stored value (including None).
SENTINEL = object()


class TreeCache(object):
    """
    Tree-based backing store for LruCache. Allows subtrees of data to be deleted
    efficiently.
    Keys must be tuples.

    The tree is held as nested dicts rooted at `root`: every element of a key
    except the last selects a child dict, and the last element maps to an
    `_Entry` wrapping the stored value. `size` is the number of stored entries.
    """

    def __init__(self):
        self.size = 0
        self.root = {}  # type: Dict

    def __setitem__(self, key, value):
        return self.set(key, value)

    def __contains__(self, key):
        return self.get(key, SENTINEL) is not SENTINEL

    def set(self, key, value):
        """Store `value` under `key`, creating intermediate nodes as needed.

        Args:
            key (tuple): the full key of the entry.
            value: the value to store.
        """
        node = self.root
        for k in key[:-1]:
            node = node.setdefault(k, {})

        last = key[-1]
        previous = node.get(last, SENTINEL)
        node[last] = _Entry(value)

        if previous is SENTINEL:
            self.size += 1
        else:
            # Something was displaced (an entry, or a whole subtree): only the
            # net change counts, otherwise re-setting a key inflates the size.
            _, displaced = _strip_and_count_entires(previous)
            self.size += 1 - displaced

    def get(self, key, default=None):
        """Return the value stored under the full key `key`.

        Args:
            key (tuple): the full key of the entry.
            default: returned when there is no entry for `key`.
        """
        node = self.root
        for k in key[:-1]:
            node = node.get(k, None)
            if node is None:
                return default

        # Look up with a sentinel rather than building a throwaway _Entry.
        entry = node.get(key[-1], SENTINEL)
        if entry is SENTINEL:
            return default
        return entry.value

    def clear(self):
        """Drop every entry and reset the size to zero."""
        self.size = 0
        self.root = {}

    def pop(self, key, default=None):
        """Remove and return whatever is stored under `key`.

        `key` may be a full key, in which case the stored value is returned, or
        a prefix of stored keys, in which case the whole subtree is removed and
        returned as nested dicts of plain values.

        Args:
            key (tuple): the key, or key prefix, to remove.
            default: returned when nothing is stored under `key`.
        """
        # (parent, key element, child) for each step taken down the tree.
        path = []

        node = self.root
        for k in key[:-1]:
            child = node.get(k, None)
            if child is None:
                return default
            path.append((node, k, child))
            node = child

        popped = node.pop(key[-1], SENTINEL)
        if popped is SENTINEL:
            return default

        # Walk back towards the root, pruning nodes the removal left empty and
        # stopping at the first one that still holds something.
        for parent, k, child in reversed(path):
            if child:
                break
            del parent[k]

        popped, cnt = _strip_and_count_entires(popped)
        self.size -= cnt
        return popped

    def values(self):
        """Return a list of all the stored values."""
        return list(iterate_tree_cache_entry(self.root))

    def __len__(self):
        return self.size


def iterate_tree_cache_entry(d):
    """Generator yielding the leaf values of a tree of nested dicts.

    Dicts are treated as interior nodes and descended into; anything else is a
    leaf. A leaf wrapped in an `_Entry` is yielded unwrapped, any other leaf is
    yielded unchanged.
    """
    if not isinstance(d, dict):
        # A leaf: unwrap it if it is one of our entries.
        yield d.value if isinstance(d, _Entry) else d
        return

    for child in itervalues(d):
        for leaf in iterate_tree_cache_entry(child):
            yield leaf


class _Entry(object):
    __slots__ = ["value"]

    def __init__(self, value):
        self.value = value


def _strip_and_count_entires(d):
    """Takes an _Entry or dict with leaves of _Entry's, and either returns the
    value or a dictionary with _Entry's replaced by their values.

    Also returns the count of _Entry's
    """
    if isinstance(d, dict):
        cnt = 0
        for key, value in d.items():
            v, n = _strip_and_count_entires(value)
            d[key] = v
            cnt += n
        return d, cnt
    else:
        return d.value, 1