import tensorflow as tf
from tqdm import tqdm
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader, Dataset
from models import DspritesNet, DspritesNetGen  # DspritesNetGen is used by genbaseline
from utils import ReplayBuffer
import os.path as osp
import numpy as np
from scipy.misc import imsave
import os
from custom_adam import AdamOptimizer

flags.DEFINE_integer('batch_size', 256, 'batch size used for training and evaluation')
flags.DEFINE_integer('data_workers', 4, 'number of workers for data loading')
flags.DEFINE_string('logdir', 'cachedir', 'directory for logging')
flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_integer('num_filters', 64, 'number of filters in the conv nets')
flags.DEFINE_float('step_lr', 500, 'step size for Langevin dynamics gradient descent')
flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters')
flags.DEFINE_bool('cclass', True, 'whether to use class-conditional models')
flags.DEFINE_bool('proj_cclass', False, 'kept for backwards compatibility')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results')
flags.DEFINE_integer('num_steps', 20, 'number of steps of Langevin dynamics (image sampling or label optimization)')
flags.DEFINE_string('task', 'conceptcombine', 'one of conceptcombine, labeldiscover, gentest, or genbaseline')
flags.DEFINE_bool('joint_shape', False, 'generalize jointly over position and shape')
flags.DEFINE_bool('joint_rot', False, 'generalize jointly over position and rotation')

# Which factors to condition on (i.e., which models to compose)
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale')

flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of the size-conditioned experiment')
flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of the shape-conditioned experiment')
flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of the position-conditioned experiment')
flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of the rotation-conditioned experiment')
flags.DEFINE_integer('resume_size', 169000, 'iteration from which to resume the size model')
flags.DEFINE_integer('resume_shape', 477000, 'iteration from which to resume the shape model')
flags.DEFINE_integer('resume_pos', 8000, 'iteration from which to resume the position model')
flags.DEFINE_integer('resume_rot', 690000, 'iteration from which to resume the rotation model')
flags.DEFINE_integer('break_steps', 300, 'number of training steps after which to stop early')

# Whether to train for gentest
flags.DEFINE_bool('train', False, 'whether to train the generalization model before evaluating')

FLAGS = flags.FLAGS
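
# dSprites latents_values columns: 0 color, 1 shape (1-3), 2 scale (0.5-1),
# 3 orientation (radians), 4 posX, 5 posY (both in [0, 1]).
# All of the l[:, i] indexing below follows this layout.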

class DSpritesGen(Dataset):
    def __init__(self, data, latents, frac=0.0):

        l = latents

        if FLAGS.joint_shape:
            # vary shape; fix orientation, position, and scale
            mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
        elif FLAGS.joint_rot:
            # vary rotation; fix shape, position, and scale
            mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
        else:
            # vary scale; fix orientation, position, and shape
            mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1)

        # vary position; fix shape, orientation, and scale
        mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)

        data_pos = data[mask_pos]
        l_pos = l[mask_pos]

        data_size = data[mask_size]
        l_size = l[mask_size]

        # Replicate the much smaller held-out-factor slice so the two training
        # slices are roughly balanced in the combined dataset
        n = data_pos.shape[0] // data_size.shape[0]

        data_size = np.tile(data_size, (n, 1, 1))
        l_size = np.tile(l_size, (n, 1))

        self.data = np.concatenate((data_pos, data_size), axis=0)
        self.label = np.concatenate((l_pos, l_size), axis=0)

        # Held-out factor combinations, excluded from both training slices; a
        # fraction `frac` of these is mixed in for the generalization curve
        mask_neg = (~(mask_size | mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39))
        data_add = data[mask_neg]
        l_add = l[mask_neg]

        perm_idx = np.random.permutation(data_add.shape[0])
        select_idx = perm_idx[:int(frac*perm_idx.shape[0])]
        data_add = data_add[select_idx]
        l_add = l_add[select_idx]

        self.data = np.concatenate((self.data, data_add), axis=0)
        self.label = np.concatenate((self.label, l_add), axis=0)

        self.identity = np.eye(3)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        im = self.data[index]
        im_corrupt = 0.5 + 0.5 * np.random.randn(64, 64)

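        # Encode the conditioning label for the held-out factor: a one-hot
        # shape vector, a (cos, sin)-encoded rotation, or the scalar scale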
        if FLAGS.joint_shape:
            label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
        elif FLAGS.joint_rot:
            label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])])
        else:
            label_size = self.label[index, 2:3]

        label_pos = self.label[index, 4:]

        return (im_corrupt, im, label_size, label_pos)


def labeldiscover(sess, kvs, data, latents, save_exp_dir):
    LABEL_SIZE = kvs['LABEL_SIZE']
    model_size = kvs['model_size']
    weight_size = kvs['weight_size']
    x_mod = kvs['X_NOISE']

    label_output = LABEL_SIZE
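    # Discover the size label by gradient descent in label space: perturb the
    # label with Gaussian noise, step down the energy gradient, and clip to
    # the valid dSprites scale range [0.5, 1]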
    for i in range(FLAGS.num_steps):
        label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03)
        e_noise = model_size.forward(x_mod, weight_size, label=label_output)
        label_grad = tf.gradients(e_noise, [label_output])[0]
        # label_grad = tf.Print(label_grad, [label_grad])
        label_output = label_output - 1.0 * label_grad
        label_output = tf.clip_by_value(label_output, 0.5, 1.0)

    diffs = []
    for i in range(30):
        s = i*FLAGS.batch_size
        d = (i+1)*FLAGS.batch_size
        data_i = data[s:d]
        latent_i = latents[s:d]
        latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))

        feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init}
        size_pred = sess.run([label_output], feed_dict)[0]
        size_gt = latent_i[:, 2:3]

        diffs.append(np.abs(size_pred - size_gt).mean())

    print("Mean absolute error of discovered size labels: {}".format(np.array(diffs).mean()))


def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
    # tf.reset_default_graph()
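    # Baseline for the generalization task: a feedforward generator that maps
    # a random latent vector plus the conditioning label directly to an image
    # and is trained with an MSE loss, instead of sampling from composed EBMs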

    if FLAGS.joint_shape:
        model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=5)
        LABEL = tf.placeholder(shape=(None, 5), dtype=tf.float32)
    else:
        model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
        LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)

    weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))

    X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32)
    X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
    loss_sq = tf.reduce_mean(tf.square(X_out - X_label))

    optimizer = AdamOptimizer(1e-3)
    gvs = optimizer.compute_gradients(loss_sq)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)

    datafull = data

    itr = 0
    saver = tf.train.Saver()

    vs = optimizer.variables()
    sess.run(tf.global_variables_initializer())

    if FLAGS.train:
        for _ in range(5):
            for data_corrupt, data, label_size, label_pos in tqdm(dataloader):

                # Convert torch tensors to numpy arrays for the TF feed_dict
                data = data.numpy()
                label_size, label_pos = label_size.numpy(), label_pos.numpy()

                # The baseline generates from a random latent vector, so the
                # corrupted image is replaced by latent-space Gaussian noise
                data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters)
                label_comb = np.concatenate([label_size, label_pos], axis=1)

                feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}

                output = [loss_sq, train_op]

                loss, _ = sess.run(output, feed_dict=feed_dict)

                itr += 1

        saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))

    saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))

    l = latents

    # Evaluate on factor combinations excluded from both training slices
    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) & (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]
    losses = []

    for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
        data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters)

        if FLAGS.joint_shape:
            latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
            latent_pos = latent[:, 4:6]
            latent = np.concatenate([latent_size, latent_pos], axis=1)
            feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
        else:
            feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat}
        loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
        # print(loss)
        losses.append(loss)

    print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))


    data_try = data_gen[:10]
    data_init = np.random.randn(10, 2*FLAGS.num_filters)

    if FLAGS.joint_shape:
        latent_scale = np.eye(3)[latents_gen[:10, 1].astype(np.int32) - 1]
        latent_pos = latents_gen[:10, 4:]
    else:
        latent_scale = latents_gen[:10, 2:3]
        latent_pos = latents_gen[:10, 4:]

    latent_tot = np.concatenate([latent_scale, latent_pos], axis=1)

    feed_dict = {X_feed: data_init, LABEL: latent_tot}
    x_output = sess.run([X_out], feed_dict=feed_dict)[0]
    x_output = np.clip(x_output, 0, 1)

    im_name = "size_scale_combine_genbaseline.png"

    x_output_wrap = np.ones((10, 66, 66))
    data_try_wrap = np.ones((10, 66, 66))

    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))

    return np.mean(losses)


def gentest(sess, kvs, data, latents, save_exp_dir):
    X_NOISE = kvs['X_NOISE']
    LABEL_SIZE = kvs['LABEL_SIZE']
    LABEL_SHAPE = kvs['LABEL_SHAPE']
    LABEL_POS = kvs['LABEL_POS']
    LABEL_ROT = kvs['LABEL_ROT']
    model_size = kvs['model_size']
    model_shape = kvs['model_shape']
    model_pos = kvs['model_pos']
    model_rot = kvs['model_rot']
    weight_size = kvs['weight_size']
    weight_shape = kvs['weight_shape']
    weight_pos = kvs['weight_pos']
    weight_rot = kvs['weight_rot']
    X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    datafull = data
    # Test generalization by composing the position model with the model for
    # the held-out factor, each trained on its own slice of the data
    x_final = X_NOISE
    x_mod_size = X_NOISE
    x_mod_pos = X_NOISE
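    # Sampling loop: alternate Langevin steps on the position-conditioned EBM
    # and on the EBM for the held-out factor (size, shape, or rotation), so
    # the final sample satisfies both conditions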

    for i in range(FLAGS.num_steps):

        # use cond_pos

        energies = []
        # Langevin step on the position-conditioned EBM: perturb with noise,
        # then descend the energy gradient and clip to the valid pixel range
        x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
        e_noise = model_pos.forward(x_mod_pos, weight_pos, label=LABEL_POS)

        # energies.append(e_noise)
        x_grad = tf.gradients(e_noise, [x_mod_pos])[0]
        x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
        x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)

        if FLAGS.joint_shape:
            # use cond_shape
            e_noise = model_shape.forward(x_mod_pos, weight_shape, label=LABEL_SHAPE)
        elif FLAGS.joint_rot:
            e_noise = model_rot.forward(x_mod_pos, weight_rot, label=LABEL_ROT)
        else:
            # use cond_size
            e_noise = model_size.forward(x_mod_pos, weight_size, label=LABEL_SIZE)

        # energies.append(e_noise)
        # energy_stack = tf.concat(energies, axis=1)
        # energy_stack = tf.reduce_logsumexp(-1*energy_stack, axis=1)
        # energy_stack = tf.reduce_sum(energy_stack, axis=1)

        x_grad = tf.gradients(e_noise, [x_mod_pos])[0]
        x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
        x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)

        # for x_mod_size
        # use cond_size
        # e_noise = model_size.forward(x_mod_size, weight_size, label=LABEL_SIZE)
        # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
        # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
        # x_mod_size = x_mod_size - FLAGS.step_lr * x_grad
        # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)

        # # use cond_pos
        # e_noise = model_pos.forward(x_mod_size, weight_pos, label=LABEL_POS)
        # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
        # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
        # x_mod_size = x_mod_size - FLAGS.step_lr * tf.stop_gradient(x_grad)
        # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)

    x_mod = x_mod_pos
    x_final = x_mod


    if FLAGS.joint_shape:
        loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
                  model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)

        energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
                      model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)

        energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
                      model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
    elif FLAGS.joint_rot:
        loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
                  model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)

        energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
                      model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)

        energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
                      model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
    else:
        loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
                    model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)

        energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
                      model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)

        energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
                      model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)

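    # Contrastive objective: push down the energy of real images (energy_pos)
    # and push up the energy of sampled negatives (energy_neg), plus an L2
    # penalty on both energies and a sampler loss (loss_kl) on the final
    # samples. neg_loss below is an importance-weighted variant that is
    # computed but not used in loss_total.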
    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
    coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
    neg_loss = coeff * (-1*energy_neg) / norm_constant

    loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
    loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))

    optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
    gvs = optimizer.compute_gradients(loss_total)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    vs = optimizer.variables()
    sess.run(tf.variables_initializer(vs))

    dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)

    x_off = tf.reduce_mean(tf.square(x_mod - X))

    itr = 0
    saver = tf.train.Saver()
    x_mod = None  # numpy array of the most recent negative samples, fed to the replay buffer


    if FLAGS.train:
        replay_buffer = ReplayBuffer(10000)
        for _ in range(1):


            for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
                data_corrupt = data_corrupt.numpy()[:, :, :]
                data = data.numpy()[:, :, :]

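                # Persistent chains: seed roughly 5% of each batch with past
                # negative samples from the replay buffer instead of noise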
                if x_mod is not None:
                    replay_buffer.add(x_mod)
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.joint_shape:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
                elif FLAGS.joint_rot:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
                else:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}

                _, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
                itr += 1

                if itr % 10 == 0:
                    print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr))

                if itr == FLAGS.break_steps:
                    break


        saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))

    saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))

    l = latents

    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
    elif FLAGS.joint_rot:
        mask_gen = (l[:, 1] == 1) & (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) & (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]

    losses = []

    for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
        x = 0.5 + 0.5 * np.random.randn(*dat.shape)

        if FLAGS.joint_shape:
            feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
        elif FLAGS.joint_rot:
            feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
        else:
            feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}

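        # Run the sampler twice, re-feeding its own output (2 * num_steps
        # Langevin steps in total), before measuring reconstruction error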
        for i in range(2):
            x = sess.run([x_final], feed_dict=feed_dict)[0]
            feed_dict[X_NOISE] = x

        loss = sess.run([x_off], feed_dict=feed_dict)[0]
        losses.append(loss)

    print("Mean MSE loss of {} ".format(np.mean(losses)))

    data_try = data_gen[:10]
    data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
    latent_scale = latents_gen[:10, 2:3]
    latent_pos = latents_gen[:10, 4:]

    if FLAGS.joint_shape:
        feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latents_gen[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos}
    elif FLAGS.joint_rot:
        feed_dict = {LABEL_ROT: np.concatenate([np.cos(latents_gen[:10, 3:4]), np.sin(latents_gen[:10, 3:4])], axis=1), LABEL_POS: latents_gen[:10, 4:], X_NOISE: data_init}
    else:
        feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}

    x_output = sess.run([x_final], feed_dict=feed_dict)[0]

    if FLAGS.joint_shape:
        im_name = "size_shape_combine_gentest.png"
    else:
        im_name = "size_scale_combine_gentest.png"

    x_output_wrap = np.ones((10, 66, 66))
    data_try_wrap = np.ones((10, 66, 66))

    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))



def conceptcombine(sess, kvs, data, latents, save_exp_dir):
    X_NOISE = kvs['X_NOISE']
    LABEL_SIZE = kvs['LABEL_SIZE']
    LABEL_SHAPE = kvs['LABEL_SHAPE']
    LABEL_POS = kvs['LABEL_POS']
    LABEL_ROT = kvs['LABEL_ROT']
    model_size = kvs['model_size']
    model_shape = kvs['model_shape']
    model_pos = kvs['model_pos']
    model_rot = kvs['model_rot']
    weight_size = kvs['weight_size']
    weight_shape = kvs['weight_shape']
    weight_pos = kvs['weight_pos']
    weight_rot = kvs['weight_rot']

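    # Compose the selected conditional EBMs by alternating a noisy gradient
    # (Langevin) step on each model's energy; descending the summed energies
    # approximately samples from the product of the individual models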
    x_mod = X_NOISE
    for i in range(FLAGS.num_steps):

        if FLAGS.cond_scale:
            e_noise = model_size.forward(x_mod, weight_size, label=LABEL_SIZE)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        if FLAGS.cond_shape:
            e_noise = model_shape.forward(x_mod, weight_shape, label=LABEL_SHAPE)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        if FLAGS.cond_pos:
            e_noise = model_pos.forward(x_mod, weight_pos, label=LABEL_POS)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        if FLAGS.cond_rot:
            e_noise = model_rot.forward(x_mod, weight_rot, label=LABEL_ROT)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        print("Finished constructing loop {}".format(i))

    x_final = x_mod

    data_try = data[:10]
    data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
    label_scale = latents[:10, 2:3]
    label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)]
    label_rot = latents[:10, 3:4]
    label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
    label_pos = latents[:10, 4:]

    feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos,
                 LABEL_ROT: label_rot}
    x_out = sess.run([x_final], feed_dict)[0]

    im_name = "im"

    if FLAGS.cond_scale:
        im_name += "_condscale"

    if FLAGS.cond_shape:
        im_name += "_condshape"

    if FLAGS.cond_pos:
        im_name += "_condpos"

    if FLAGS.cond_rot:
        im_name += "_condrot"

    im_name += ".png"

    x_out_pad, data_try_pad = np.ones((10, 66, 66)), np.ones((10, 66, 66))
    x_out_pad[:, 1:-1, 1:-1] = x_out
    data_try_pad[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))

def main():
    dataset = np.load(FLAGS.dsprites_path)
    data = dataset['imgs']
    latents = dataset['latents_values']

    np.random.seed(1)
    idx = np.random.permutation(data.shape[0])

    data = data[idx]
    latents = latents[idx]

    config = tf.ConfigProto()
    sess = tf.Session(config=config)

    # Model 1 will be conditioned on size
    model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
    weight_size = model_size.construct_weights('context_0')

    # Model 2 will be conditioned on shape
    model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
    weight_shape = model_shape.construct_weights('context_1')

    # Model 3 will be conditioned on position
    model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
    weight_pos = model_pos.construct_weights('context_2')

    # Model 4 will be conditioned on rotation
    model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
    weight_rot = model_rot.construct_weights('context_3')

    sess.run(tf.global_variables_initializer())
    save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size))

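    # Each model was built under its own scope (context_0..context_3), but
    # every checkpoint was saved under 'context_0'; remap variable names so
    # each checkpoint restores into the matching model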
    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0))
    v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list}

    if FLAGS.cond_scale:
        saver = tf.train.Saver(v_map)
        saver.restore(sess, save_path_size)

    save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape))

    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
    v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}

    if FLAGS.cond_shape:
        saver = tf.train.Saver(v_map)
        saver.restore(sess, save_path_shape)


    save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos))
    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
    v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
    saver = tf.train.Saver(v_map)

    if FLAGS.cond_pos:
        saver.restore(sess, save_path_pos)


    save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot))
    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3))
    v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list}
    saver = tf.train.Saver(v_map)

    if FLAGS.cond_rot:
        saver.restore(sess, save_path_rot)

    X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
    LABEL_SIZE = tf.placeholder(shape=(None, 1), dtype=tf.float32)
    LABEL_SHAPE = tf.placeholder(shape=(None, 3), dtype=tf.float32)
    LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
    LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)


    kvs = {
        'X_NOISE': X_NOISE,
        'LABEL_SIZE': LABEL_SIZE,
        'LABEL_SHAPE': LABEL_SHAPE,
        'LABEL_POS': LABEL_POS,
        'LABEL_ROT': LABEL_ROT,
        'model_size': model_size,
        'model_shape': model_shape,
        'model_pos': model_pos,
        'model_rot': model_rot,
        'weight_size': weight_size,
        'weight_shape': weight_shape,
        'weight_pos': weight_pos,
        'weight_rot': weight_rot,
    }

    save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape))
    if not osp.exists(save_exp_dir):
        os.makedirs(save_exp_dir)


    if FLAGS.task == 'conceptcombine':
        conceptcombine(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'labeldiscover':
        labeldiscover(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'gentest':
        save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos))
        if not osp.exists(save_exp_dir):
            os.makedirs(save_exp_dir)

        gentest(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'genbaseline':
        save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos))
        if not osp.exists(save_exp_dir):
            os.makedirs(save_exp_dir)

        if FLAGS.plot_curve:
            mse_losses = []
            for frac in [i/10 for i in range(11)]:
                mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac)
                mse_losses.append(mse_loss)
            np.save("mse_baseline_comb.npy", mse_losses)
        else:
            genbaseline(sess, kvs, data, latents, save_exp_dir)



if __name__ == "__main__":
    main()