from tensorflow.python.platform import flags
from tensorflow.contrib.data.python.ops import batching, threadpool
import tensorflow as tf
import json
from torch.utils.data import Dataset
import pickle
import os.path as osp
import os
import numpy as np
import time
from scipy.misc import imread, imresize
from skimage.color import rgb2grey
from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
from torchvision import transforms
from imagenet_preprocessing import ImagenetPreprocessor
import torch
import torchvision

FLAGS = flags.FLAGS
# Root directory handed to torchvision's CIFAR10 (download=True) for caching.
ROOT_DIR = "./results"

# Dataset Options
flags.DEFINE_string(
    'dsprites_path',
    '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
    'path to dsprites characters')
flags.DEFINE_string(
    'imagenet_datadir', '/root/imagenet_big',
    'directory containing the ImageNet TFRecord files and their index.json')
flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
flags.DEFINE_bool('dsize_only', False, 'fix all factors except for size of objects')
flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
flags.DEFINE_bool(
    'dsprites_restrict', False,
    'restrict dsprites to shape 1 at zero rotation (other factors vary)')
flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')


# Data augmentation options (consumed by cutout())
flags.DEFINE_bool('cutout_inside', False, 'whether cutoff should always in image')
flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout')
flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout')
flags.DEFINE_bool('cutout', False, 'whether to add cutout regularizer to data')


def cutout(mask_color=(0, 0, 0), mask_size=None, prob=None, inside=None):
    """Build a Cutout augmentation that fills a square patch of a CHW image.

    The returned callable is placed after ``ToTensor()`` in a
    ``transforms.Compose`` pipeline, so it works on channel-first images
    (``C x H x W``). It fills the same square in every channel with
    ``mask_color``.

    Args:
        mask_color: per-channel fill value (one entry per channel).
        mask_size: side length of the square patch. Defaults to
            ``FLAGS.cutout_mask_size``.
        prob: probability of applying the patch to a given image. Defaults to
            ``FLAGS.cutout_prob`` (read on every call).
        inside: if true, the patch centre is sampled so that the whole patch
            lies inside the image; otherwise the patch may be clipped at the
            border. Defaults to ``FLAGS.cutout_inside`` (read on every call).

    Returns:
        A function mapping an image to a masked copy. The input is never
        modified. A ``torch.Tensor`` input yields a ``torch.Tensor``; any
        other input yields a ``numpy.ndarray``.
    """
    if mask_size is None:
        mask_size = FLAGS.cutout_mask_size
    mask_size_half = mask_size // 2
    # For even sizes this keeps the last fully-inside patch position reachable.
    offset = 1 if mask_size % 2 == 0 else 0
    # Shape (C, 1, 1) so the colour broadcasts over the patch's spatial extent.
    fill = np.array(mask_color)[:, None, None]

    def _cutout(image):
        is_tensor = isinstance(image, torch.Tensor)
        image = np.asarray(image).copy()

        apply_prob = FLAGS.cutout_prob if prob is None else prob
        if np.random.random() <= apply_prob:
            # Channel-first layout: the spatial dims are the last two axes.
            h, w = image.shape[-2:]

            use_inside = FLAGS.cutout_inside if inside is None else inside
            if use_inside:
                cxmin, cxmax = mask_size_half, w + offset - mask_size_half
                cymin, cymax = mask_size_half, h + offset - mask_size_half
            else:
                cxmin, cxmax = 0, w + offset
                cymin, cymax = 0, h + offset

            cx = np.random.randint(cxmin, cxmax)
            cy = np.random.randint(cymin, cymax)
            xmin = cx - mask_size_half
            ymin = cy - mask_size_half
            xmax = min(w, xmin + mask_size)
            ymax = min(h, ymin + mask_size)
            xmin = max(0, xmin)
            ymin = max(0, ymin)
            image[..., ymin:ymax, xmin:xmax] = fill

        # Keep the pipeline's type: later code may call tensor-only methods.
        return torch.from_numpy(image) if is_tensor else image

    return _cutout


class TFImagenetLoader(Dataset):
    """Endless ImageNet batch iterator driven by a TF1 ``tf.data`` pipeline.

    Record-file names and sizes come from ``<FLAGS.imagenet_datadir>/index.json``
    (key ``record_counts``). Records are decoded by ``ImagenetPreprocessor`` at
    128x128, then shuffled and repeated indefinitely. They are batched with
    ``FLAGS.batch_size``. Every ``next()`` returns the torch tensors
    ``(im_corrupt, im, label)``; iteration never raises ``StopIteration``.
    """

    def __init__(self, split, batchsize, idx, num_workers, rescale=1):
        """Build the input pipeline and a CPU-only ``tf.Session`` to run it.

        Args:
            split: ``"train"`` reads the ``train*`` record files and keeps only
                this worker's contiguous shard. Any other value reads every
                ``validation*`` file.
            batchsize: only used to compute ``len(self)``. The pipeline itself
                batches with ``FLAGS.batch_size``.
            idx: zero-based index of this worker's shard (train split only).
            num_workers: number of shards the train records are split into.
            rescale: factor applied to images after they are mapped to [0, 1].
        """
        IMAGENET_NUM_TRAIN_IMAGES = 1281167
        IMAGENET_NUM_VAL_IMAGES = 50000

        self.rescale = rescale
        self.curr_sample = 0

        index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
        with open(index_path) as f:
            counts = json.load(f)['record_counts']

        if split == "train":
            im_length = IMAGENET_NUM_TRAIN_IMAGES
            # This worker owns records [records_to_skip, records_to_skip + shard_records).
            records_to_skip = im_length * idx // num_workers
            shard_records = im_length * (idx + 1) // num_workers - records_to_skip
            files, first_file_skip = self._select_train_files(
                counts, records_to_skip, shard_records)
        else:
            im_length = IMAGENET_NUM_VAL_IMAGES
            files = sorted(x for x in counts if x.startswith('validation'))
            first_file_skip = 0
            shard_records = im_length

        files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
        preprocess_function = ImagenetPreprocessor(
            128, dtype=tf.float32, train=False).parse_and_preprocess

        ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
        ds = ds.apply(tf.data.TFRecordDataset)
        # The first selected file may begin inside the previous worker's shard,
        # and the last may run into the next one. Trim to exactly this shard.
        ds = ds.skip(first_file_skip)
        ds = ds.take(shard_records)
        ds = ds.prefetch(buffer_size=FLAGS.batch_size)
        ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
        ds = ds.apply(batching.map_and_batch(
            map_func=preprocess_function,
            batch_size=FLAGS.batch_size,
            num_parallel_batches=4))
        ds = ds.prefetch(buffer_size=2)

        ds_iterator = ds.make_initializable_iterator()
        labels, images = ds_iterator.get_next()
        # Divide by 256 and add uniform noise of width 1/256, clipped to [0, 1].
        # Assumes the preprocessor emits pixel values in [0, 255] — TODO confirm.
        self.images = tf.clip_by_value(
            images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
        self.labels = labels

        # Hide GPUs from this session (presumably to leave them to PyTorch).
        config = tf.ConfigProto(device_count={'GPU': 0})
        sess = tf.Session(config=config)
        sess.run(ds_iterator.initializer)

        # NOTE(review): uses the full split size, not this worker's shard size.
        self.im_length = im_length // batchsize
        self.sess = sess

    @staticmethod
    def _select_train_files(counts, records_to_skip, records_to_read):
        """Return the sorted ``train*`` files covering a contiguous record range.

        Args:
            counts: mapping from record-file name to its number of records.
            records_to_skip: global index of the first record in the range.
            records_to_read: number of records in the range.

        Returns:
            ``(files, first_file_skip)``. ``first_file_skip`` is the number of
            leading records of ``files[0]`` that precede the range; it is 0 when
            no file is selected.
        """
        first_file_skip = None
        files = []
        for filename in sorted(x for x in counts if x.startswith('train')):
            records_in_file = counts[filename]
            if records_to_skip >= records_in_file:
                # The whole file lies before the range.
                records_to_skip -= records_in_file
                continue
            if records_to_read <= 0:
                break
            if first_file_skip is None:
                first_file_skip = records_to_skip
            files.append(filename)
            records_to_read -= records_in_file - records_to_skip
            records_to_skip = 0
        return files, first_file_skip or 0

    def __next__(self):
        """Fetch the next batch as torch tensors ``(im_corrupt, im, label)``.

        ``im_corrupt`` is uniform noise in ``[0, rescale)`` with shape
        ``(FLAGS.batch_size, 128, 128, 3)``. ``im`` is the decoded batch scaled
        by ``rescale``. ``label`` is one-hot over 1000 classes; record labels
        are assumed to be 1-based, hence the ``- 1``.
        """
        self.curr_sample += 1

        im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
        label, im = self.sess.run([self.labels, self.images])
        im = im * self.rescale
        # reshape(-1) instead of squeeze() so a batch of one stays 2-D after one-hot.
        label = np.eye(1000)[label.reshape(-1) - 1]
        return torch.from_numpy(im_corrupt), torch.from_numpy(im), torch.from_numpy(label)

    def __iter__(self):
        return self

    def __len__(self):
        return self.im_length

class CelebA(Dataset):
    """Aligned CelebA faces resized to 32x32, each with the fixed label 1.

    Items are ``(im_corrupt, im, label)``. ``im`` is the resized image divided
    by 255. ``im_corrupt`` is chosen by ``FLAGS.datasource``. The corrupt
    sample has shape (32, 32, 3), so RGB files are assumed.

    Args:
        path: directory containing the image files.
    """

    def __init__(self, path="/root/data/img_align_celeba"):
        self.path = path
        # Sorted so the index -> file mapping does not depend on filesystem order.
        self.ims = [osp.join(path, name) for name in sorted(os.listdir(path))]

    def __len__(self):
        return len(self.ims)

    def __getitem__(self, index):
        label = 1

        if FLAGS.single:
            index = 0

        image_size = 32
        im = imresize(imread(self.ims[index]), (image_size, image_size))
        im = im / 255.

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0, 1, size=(image_size, image_size, 3))
        else:
            raise ValueError(
                "unsupported FLAGS.datasource: {!r}".format(FLAGS.datasource))

        return im_corrupt, im, label


class Cifar10(Dataset):
    """CIFAR-10 as ``(im_corrupt, im, one_hot_label)`` triples in HWC layout.

    Args:
        train: load the training split (``True``) or the test split.
        full: also serve the CIFAR-10 test split after ``self.data``. Indices
            ``>= len(self.data)`` map into ``self.test_data``.
        augment: random crop (padding 4) and horizontal flip, plus
            ``cutout()`` when ``FLAGS.cutout`` is set.
        noise: when true, scale ``im`` by ``rescale`` and add uniform noise in
            ``[0, rescale / 256)``.
        rescale: see ``noise``. It is also the upper bound of ``im_corrupt``
            when ``FLAGS.datasource == 'random'``.
    """

    def __init__(
            self,
            train=True,
            full=False,
            augment=False,
            noise=True,
            rescale=1.0):

        if augment:
            transform_list = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]

            if FLAGS.cutout:
                transform_list.append(cutout())

            transform = transforms.Compose(transform_list)
        else:
            transform = transforms.ToTensor()

        self.full = full
        self.data = CIFAR10(
            ROOT_DIR,
            transform=transform,
            train=train,
            download=True)
        self.test_data = CIFAR10(
            ROOT_DIR,
            transform=transform,
            train=False,
            download=True)
        self.one_hot_map = np.eye(10)
        self.noise = noise
        self.rescale = rescale

    def __len__(self):
        if self.full:
            return len(self.data) + len(self.test_data)
        return len(self.data)

    def __getitem__(self, index):
        if FLAGS.single:
            im, label = self.data[0]
        elif self.full and index >= len(self.data):
            im, label = self.test_data[index - len(self.data)]
        else:
            im, label = self.data[index]

        # CHW -> HWC. np.asarray accepts both torch tensors and numpy arrays,
        # because the last transform in the pipeline may return either.
        im = np.transpose(np.asarray(im), (1, 2, 0))
        image_size = 32
        label = self.one_hot_map[label]

        im = im * 255 / 256

        # NOTE(review): rescale is only applied to `im` when noise is enabled.
        if self.noise:
            im = im * self.rescale + \
                np.random.uniform(0, self.rescale * 1 / 256., im.shape)

        # Reseed from index + wall clock, presumably so that DataLoader workers
        # do not draw identical corruption noise.
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, self.rescale, (image_size, image_size, 3))
        else:
            raise ValueError(
                "unsupported FLAGS.datasource: {!r}".format(FLAGS.datasource))

        return im_corrupt, im, label


class Cifar100(Dataset):
    """CIFAR-100 as ``(im_corrupt, im, one_hot_label)`` triples in HWC layout.

    Args:
        train: load the training split (``True``) or the test split.
        augment: random crop (padding 4) and horizontal flip, plus
            ``cutout()`` when ``FLAGS.cutout`` is set.
    """

    def __init__(self, train=True, augment=False):

        if augment:
            transform_list = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]

            if FLAGS.cutout:
                transform_list.append(cutout())

            transform = transforms.Compose(transform_list)
        else:
            transform = transforms.ToTensor()

        self.data = CIFAR100(
            "/root/cifar100",
            transform=transform,
            train=train,
            download=True)
        self.one_hot_map = np.eye(100)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        if FLAGS.single:
            im, label = self.data[0]
        else:
            im, label = self.data[index]

        # CHW -> HWC. np.asarray accepts both torch tensors and numpy arrays,
        # because the last transform in the pipeline may return either.
        im = np.transpose(np.asarray(im), (1, 2, 0))
        image_size = 32
        label = self.one_hot_map[label]
        # Symmetric dequantisation noise of half a quantisation step.
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))
        else:
            raise ValueError(
                "unsupported FLAGS.datasource: {!r}".format(FLAGS.datasource))

        return im_corrupt, im, label


class Svhn(Dataset):
    """SVHN digits as ``(im_corrupt, im, one_hot_label)`` triples in HWC layout.

    Args:
        train: load the ``"train"`` split when true, else the ``"test"`` split.
        augment: accepted for signature parity with the other datasets; unused.
    """

    def __init__(self, train=True, augment=False):
        # torchvision's SVHN selects its split via ``split`` (it has no
        # ``train`` argument), so translate the flag explicitly.
        split = "train" if train else "test"
        self.data = SVHN(
            "/root/svhn",
            split=split,
            transform=transforms.ToTensor(),
            download=True)
        self.one_hot_map = np.eye(10)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        if FLAGS.single:
            im, label = self.data[0]
        else:
            im, label = self.data[index]

        # CHW tensor -> HWC array.
        im = np.transpose(np.asarray(im), (1, 2, 0))
        image_size = 32
        label = self.one_hot_map[label]
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))
        else:
            raise ValueError(
                "unsupported FLAGS.datasource: {!r}".format(FLAGS.datasource))

        return im_corrupt, im, label


class Mnist(Dataset):
    """MNIST digits as ``(im_corrupt, im, one_hot_label)`` with 28x28 images.

    Args:
        train: load the training split (``True``) or the test split.
        rescale: multiplier for ``im``. It is also the upper bound of
            ``im_corrupt`` when ``FLAGS.datasource == 'random'``.
    """

    def __init__(self, train=True, rescale=1.0):
        self.data = MNIST(
            "/root/mnist",
            transform=transforms.ToTensor(),
            download=True, train=train)
        self.labels = np.eye(10)
        self.rescale = rescale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        im, label = self.data[index]
        label = self.labels[label]
        image_size = 28

        # Drop the singleton channel and map [0, 1] onto [0, 255/256]. Add
        # uniform dequantisation noise of width 1/256, then apply rescale.
        im = im.squeeze().numpy() / 256 * 255 + \
            np.random.uniform(0, 1. / 256, (image_size, image_size))
        im = im * self.rescale

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(0, self.rescale, (image_size, image_size))
        else:
            raise ValueError(
                "unsupported FLAGS.datasource: {!r}".format(FLAGS.datasource))

        return im_corrupt, im, label


class DSprites(Dataset):
    """dSprites 64x64 images paired with labels derived from the latents.

    The flags ``FLAGS.dshape_only``, ``dpos_only``, ``dsize_only``,
    ``drot_only`` and ``dsprites_restrict`` are checked in that order, and the
    first one set wins. Each of them restricts the data to a masked subset
    where only the named factor varies. If none is set, the whole dataset is
    used and the ``cond_*`` arguments pick the label columns:

    * ``cond_size``: column 2
    * ``cond_shape``: column 1
    * ``cond_pos``: columns 4-5
    * ``cond_rot``: column 3, encoded as (cos, sin)
    * none of them: column 1
    """

    def __init__(
            self,
            cond_size=False,
            cond_shape=False,
            cond_pos=False,
            cond_rot=False):
        # NpzFile keeps an open file handle. The context manager closes it once
        # the arrays kept below have been read into memory.
        with np.load(FLAGS.dsprites_path) as dat:
            # Read once: each NpzFile lookup re-reads the member from the archive.
            latents = dat['latents_values']

            if FLAGS.dshape_only:
                mask = ((latents[:, 4] == 16 / 31) & (latents[:, 5] == 16 / 31)
                        & (latents[:, 2] == 0.5) & (latents[:, 3] == 30 * np.pi / 39))
                self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
                self.label = np.tile(latents[mask], (10000, 1))[:, 1:2]
            elif FLAGS.dpos_only:
                mask = ((latents[:, 1] == 1) & (latents[:, 3] == 30 * np.pi / 39)
                        & (latents[:, 2] == 0.5))
                self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
                self.label = np.tile(latents[mask], (100, 1))[:, 4:] + 0.5
            elif FLAGS.dsize_only:
                mask = ((latents[:, 3] == 30 * np.pi / 39) & (latents[:, 4] == 16 / 31)
                        & (latents[:, 5] == 16 / 31) & (latents[:, 1] == 1))
                self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
                self.label = np.tile(latents[mask], (10000, 1))[:, 2:3]
            elif FLAGS.drot_only:
                mask = ((latents[:, 2] == 0.5) & (latents[:, 4] == 16 / 31)
                        & (latents[:, 5] == 16 / 31) & (latents[:, 1] == 1))
                self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
                rot = np.tile(latents[mask], (100, 1))[:, 3:4]
                self.label = np.concatenate([np.cos(rot), np.sin(rot)], axis=1)
            elif FLAGS.dsprites_restrict:
                mask = (latents[:, 1] == 1) & (latents[:, 3] == 0 * np.pi / 39)
                self.data = dat['imgs'][mask]
                self.label = latents[mask]
            else:
                self.data = dat['imgs']
                self.label = latents

                if cond_size:
                    self.label = self.label[:, 2:3]
                elif cond_shape:
                    self.label = self.label[:, 1:2]
                elif cond_pos:
                    self.label = self.label[:, 4:]
                elif cond_rot:
                    self.label = self.label[:, 3:4]
                    self.label = np.concatenate(
                        [np.cos(self.label), np.sin(self.label)], axis=1)
                else:
                    self.label = self.label[:, 1:2]

        # Rows used to one-hot encode integer labels (1-based, hence `- 1` below).
        self.identity = np.eye(3)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        im = self.data[index]
        image_size = 64

        # NOTE(review): this reads FLAGS.cond_*, whereas __init__ chose the label
        # columns from its cond_* arguments. Callers presumably pass matching
        # values — confirm.
        continuous_label = (
            FLAGS.dpos_only or FLAGS.dsize_only or FLAGS.cond_size
            or FLAGS.cond_pos or FLAGS.cond_rot or FLAGS.drot_only)
        if continuous_label:
            label = self.label[index]
        else:
            label = self.identity[self.label[index].astype(
                np.int32) - 1].squeeze()

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
        elif FLAGS.datasource == 'random':
            im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)
        else:
            raise ValueError(
                "unsupported FLAGS.datasource: {!r}".format(FLAGS.datasource))

        return im_corrupt, im, label


class Imagenet(Dataset):
    """Downsampled 32x32 ImageNet loaded from pickled batch files.

    Each pickle holds a dict with ``'labels'`` (list of 1-based class ids) and
    ``'data'`` (rows of 3*32*32 values, reshaped to CHW in ``__getitem__``).

    Args:
        train: load ``train_data_batch_1`` .. ``_10`` when true, else ``val_data``.
        augment: accepted for signature parity with the other datasets; unused.
        path: directory holding the pickles. Defaults to ``FLAGS.imagenet_path``.
    """

    def __init__(self, train=True, augment=False, path=None):
        if path is None:
            path = FLAGS.imagenet_path

        if train:
            labels, chunks = [], []
            for i in range(1, 11):
                batch = self._load_batch(
                    osp.join(path, 'train_data_batch_{}'.format(i)))
                labels.extend(batch['labels'])
                chunks.append(batch['data'])
            # Stack once; stacking inside the loop re-copies all previous batches.
            data = np.vstack(chunks)
        else:
            batch = self._load_batch(osp.join(path, 'val_data'))
            labels = batch['labels']
            data = batch['data']

        self.labels = labels
        self.data = data
        self.one_hot_map = np.eye(1000)

    @staticmethod
    def _load_batch(filename):
        """Unpickle one batch file, closing the handle afterwards."""
        # pickle can execute arbitrary code: only point this at trusted files.
        with open(filename, 'rb') as fh:
            return pickle.load(fh)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        source = 0 if FLAGS.single else index
        im, label = self.data[source], self.labels[source]

        im = im.reshape((3, 32, 32)) / 255
        im = im.transpose((1, 2, 0))
        image_size = 32
        # Labels are 1-based; the one-hot rows are 0-based.
        label = self.one_hot_map[label - 1]
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))
        else:
            raise ValueError(
                "unsupported FLAGS.datasource: {!r}".format(FLAGS.datasource))

        return im_corrupt, im, label


class Textures(Dataset):
    """Top-left 32x32 crops of DTD texture images, scaled to [0, 1] with noise.

    ``len`` is twice the number of images: indices wrap around, so every image
    appears twice per epoch. Items are ``(im, im, label)``; the "corrupt"
    slot is the same array as the clean image.

    Args:
        train, augment: accepted for signature parity with the other
            datasets; unused.
        path: root directory passed to ``ImageFolder``.
    """

    def __init__(self, train=True, augment=False, path="/mnt/nfs/yilundu/data/dtd/images"):
        self.dataset = ImageFolder(path)

    def __len__(self):
        return 2 * len(self.dataset)

    def __getitem__(self, index):
        im, label = self.dataset[index % len(self.dataset)]

        im = np.array(im)[:32, :32] / 255
        # Symmetric dequantisation noise of half a quantisation step.
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)

        return im, im, label