import tensorflow as tf
import math
from tqdm import tqdm
from hmc import hmc
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader, Dataset
from models import DspritesNet, DspritesNetGen
from utils import optimistic_restore, ReplayBuffer
import os.path as osp
import numpy as np
from rl_algs.logger import TensorBoardOutputFormat
from scipy.misc import imsave
import os
from custom_adam import AdamOptimizer

flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
flags.DEFINE_integer('data_workers', 4, 'Number of workers used to load data')
flags.DEFINE_string('logdir', 'cachedir', 'directory for logging')
flags.DEFINE_string('savedir', 'cachedir', 'location where logs of experiments will be stored')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets')
flags.DEFINE_float('step_lr', 500, 'step size for Langevin dynamics gradient descent')
flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to the dSprites dataset')
flags.DEFINE_bool('cclass', True, 'whether to use class-conditional models')
flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons')
flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'whether to use self attention in network')
flags.DEFINE_bool('plot_curve', False, 'generate a curve of results')
flags.DEFINE_integer('num_steps', 20, 'number of Langevin steps used to optimize samples or labels')
flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.')
flags.DEFINE_bool('joint_shape', False, 'use shape + position models instead of size + position')
flags.DEFINE_bool('joint_rot', False, 'use rotation + position models instead of size + position')

# Conditions on which models to use
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale')

flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of experiment for the size-conditional model')
flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of experiment for the shape-conditional model')
flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of experiment for the position-conditional model')
flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of experiment for the rotation-conditional model')
flags.DEFINE_integer('resume_size', 169000, 'iteration to resume the size model from')
flags.DEFINE_integer('resume_shape', 477000, 'iteration to resume the shape model from')
flags.DEFINE_integer('resume_pos', 8000, 'iteration to resume the position model from')
flags.DEFINE_integer('resume_rot', 690000, 'iteration to resume the rotation model from')
flags.DEFINE_integer('break_steps', 300, 'number of training steps after which to stop')

# Whether to train for gentest
flags.DEFINE_bool('train', False, 'whether to train on generalization into multiple different predictions')

FLAGS = flags.FLAGS
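
# A quick orientation for the latent indexing used throughout this file
# (this is the documented column layout of the dSprites `latents_values`
# array): each row is (color, shape, scale, orientation, posX, posY), so the
# masks below read l[:, 1] for shape, l[:, 2] for scale, l[:, 3] for
# orientation, and l[:, 4:6] for position.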


class DSpritesGen(Dataset):
    def __init__(self, data, latents, frac=0.0):
        l = latents

        # Select the training slice on which the non-position factor is fixed.
        if FLAGS.joint_shape:
            mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
        elif FLAGS.joint_rot:
            mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
        else:
            mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1)

        mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)

        data_pos = data[mask_pos]
        l_pos = l[mask_pos]

        data_size = data[mask_size]
        l_size = l[mask_size]

        # Tile the position slice n times before concatenating the two slices.
        n = data_pos.shape[0] // data_size.shape[0]
        data_pos = np.tile(data_pos, (n, 1, 1))
        l_pos = np.tile(l_pos, (n, 1))

        self.data = np.concatenate((data_pos, data_size), axis=0)
        self.label = np.concatenate((l_pos, l_size), axis=0)

        # Optionally mix in a fraction `frac` of examples where both factors vary.
        mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39))
        data_add = data[mask_neg]
        l_add = l[mask_neg]

        perm_idx = np.random.permutation(data_add.shape[0])
        select_idx = perm_idx[:int(frac * perm_idx.shape[0])]
        data_add = data_add[select_idx]
        l_add = l_add[select_idx]

        self.data = np.concatenate((self.data, data_add), axis=0)
        self.label = np.concatenate((self.label, l_add), axis=0)

        self.identity = np.eye(3)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        im = self.data[index]
        im_corrupt = 0.5 + 0.5 * np.random.randn(64, 64)

        if FLAGS.joint_shape:
            label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
        elif FLAGS.joint_rot:
            label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])])
        else:
            label_size = self.label[index, 2:3]

        label_pos = self.label[index, 4:]

        return (im_corrupt, im, label_size, label_pos)


def labeldiscover(sess, kvs, data, latents, save_exp_dir):
    LABEL_SIZE = kvs['LABEL_SIZE']
    model_size = kvs['model_size']
    weight_size = kvs['weight_size']
    x_mod = kvs['X_NOISE']

    # Run noisy gradient descent on the label rather than on the image.
    label_output = LABEL_SIZE
    for i in range(FLAGS.num_steps):
        label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03)
        e_noise = model_size.forward(x_mod, weight_size, label=label_output)

        label_grad = tf.gradients(e_noise, [label_output])[0]
        # label_grad = tf.Print(label_grad, [label_grad])
        label_output = label_output - 1.0 * label_grad
        label_output = tf.clip_by_value(label_output, 0.5, 1.0)

    diffs = []
    for i in range(30):
        s = i * FLAGS.batch_size
        d = (i + 1) * FLAGS.batch_size
        data_i = data[s:d]
        latent_i = latents[s:d]
        latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))

        feed_dict = {x_mod: data_i, LABEL_SIZE: latent_init}
        size_pred = sess.run([label_output], feed_dict)[0]
        size_gt = latent_i[:, 2:3]

        diffs.append(np.abs(size_pred - size_gt).mean())

    # Mean absolute error between discovered and ground-truth scale labels.
    print(np.array(diffs).mean())
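
# Every task in this file runs the same kind of noisy gradient descent
# (Langevin dynamics) on some quantity -- the label above in labeldiscover,
# the image below in gentest and conceptcombine. The helper below is a
# minimal numpy sketch of one such step for illustration only; `energy_grad`
# is a hypothetical stand-in for tf.gradients of a model's energy, and
# nothing in this file calls it.
def _langevin_step_sketch(x, energy_grad, step_lr=500.0, noise_std=0.005):
    x = x + np.random.normal(0.0, noise_std, x.shape)  # diffusion noise
    x = x - step_lr * energy_grad(x)                   # descend the energy
    return np.clip(x, 0.0, 1.0)                        # keep values in [0, 1]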


def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
    # tf.reset_default_graph()

    if FLAGS.joint_shape:
        model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=5)
        LABEL = tf.placeholder(shape=(None, 5), dtype=tf.float32)
    else:
        model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
        LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)

    weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))

    X_feed = tf.placeholder(shape=(None, 2 * FLAGS.num_filters), dtype=tf.float32)
    X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
    loss_sq = tf.reduce_mean(tf.square(X_out - X_label))

    optimizer = AdamOptimizer(1e-3)
    gvs = optimizer.compute_gradients(loss_sq)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size,
                            num_workers=6, drop_last=True, shuffle=True)

    datafull = data
    itr = 0
    saver = tf.train.Saver()
    vs = optimizer.variables()
    sess.run(tf.global_variables_initializer())

    if FLAGS.train:
        for _ in range(5):
            for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
                data_corrupt = data_corrupt.numpy()
                label_size, label_pos = label_size.numpy(), label_pos.numpy()

                data_corrupt = np.random.randn(data_corrupt.shape[0], 2 * FLAGS.num_filters)
                label_comb = np.concatenate([label_size, label_pos], axis=1)

                feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}

                output = [loss_sq, train_op]
                loss, _ = sess.run(output, feed_dict=feed_dict)
                itr += 1

        saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))

    saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))

    l = latents

    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & \
            (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]

    losses = []
    for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
        data_init = np.random.randn(dat.shape[0], 2 * FLAGS.num_filters)

        if FLAGS.joint_shape:
            latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
            latent_pos = latent[:, 4:6]
            latent = np.concatenate([latent_size, latent_pos], axis=1)
            feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
        else:
            feed_dict = {X_feed: data_init, LABEL: latent[:, [2, 4, 5]], X_label: dat}

        loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
        # print(loss)
        losses.append(loss)

    print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))

    data_try = data_gen[:10]
    data_init = np.random.randn(10, 2 * FLAGS.num_filters)

    if FLAGS.joint_shape:
        latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1]
        latent_pos = latents_gen[:10, 4:]
    else:
        latent_scale = latents_gen[:10, 2:3]
        latent_pos = latents_gen[:10, 4:]

    latent_tot = np.concatenate([latent_scale, latent_pos], axis=1)

    feed_dict = {X_feed: data_init, LABEL: latent_tot}
    x_output = sess.run([X_out], feed_dict=feed_dict)[0]
    x_output = np.clip(x_output, 0, 1)

    im_name = "size_scale_combine_genbaseline.png"

    x_output_wrap = np.ones((10, 66, 66))
    data_try_wrap = np.ones((10, 66, 66))

    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66 * 2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))

    return np.mean(losses)
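
# genbaseline above (and gentest / conceptcombine below) all repeat the same
# inline pattern to save qualitative results: pad each 64x64 image into a
# 66x66 white frame and tile sample|ground-truth pairs into one tall image.
# A hypothetical helper capturing that pattern, shown for clarity only (the
# functions keep their original inline copies):
def _comparison_grid_sketch(samples, ground_truth):
    n = samples.shape[0]
    sample_wrap = np.ones((n, 66, 66))   # white 1-pixel border
    truth_wrap = np.ones((n, 66, 66))
    sample_wrap[:, 1:-1, 1:-1] = samples
    truth_wrap[:, 1:-1, 1:-1] = ground_truth
    # Each output row is one sample next to its ground truth.
    return np.concatenate([sample_wrap, truth_wrap], axis=2).reshape(-1, 66 * 2)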


def gentest(sess, kvs, data, latents, save_exp_dir):
    X_NOISE = kvs['X_NOISE']
    LABEL_SIZE = kvs['LABEL_SIZE']
    LABEL_SHAPE = kvs['LABEL_SHAPE']
    LABEL_POS = kvs['LABEL_POS']
    LABEL_ROT = kvs['LABEL_ROT']
    model_size = kvs['model_size']
    model_shape = kvs['model_shape']
    model_pos = kvs['model_pos']
    model_rot = kvs['model_rot']
    weight_size = kvs['weight_size']
    weight_shape = kvs['weight_shape']
    weight_pos = kvs['weight_pos']
    weight_rot = kvs['weight_rot']

    X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    datafull = data

    # Test combination of generalization where we use slices of both training
    x_final = X_NOISE
    x_mod_size = X_NOISE
    x_mod_pos = X_NOISE

    for i in range(FLAGS.num_steps):
        # use cond_pos
        energies = []
        x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
        e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)
        # energies.append(e_noise)
        x_grad = tf.gradients(e_noise, [x_final])[0]
        x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
        x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
        x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)

        if FLAGS.joint_shape:
            # use cond_shape
            e_noise = model_shape.forward(x_mod_pos, weight_shape, label=LABEL_SHAPE)
        elif FLAGS.joint_rot:
            # use cond_rot
            e_noise = model_rot.forward(x_mod_pos, weight_rot, label=LABEL_ROT)
        else:
            # use cond_size
            e_noise = model_size.forward(x_mod_pos, weight_size, label=LABEL_SIZE)

        # energies.append(e_noise)
        # energy_stack = tf.concat(energies, axis=1)
        # energy_stack = tf.reduce_logsumexp(-1*energy_stack, axis=1)
        # energy_stack = tf.reduce_sum(energy_stack, axis=1)

        x_grad = tf.gradients(e_noise, [x_mod_pos])[0]
        x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
        x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)

        # for x_mod_size
        # use cond_size
        # e_noise = model_size.forward(x_mod_size, weight_size, label=LABEL_SIZE)
        # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
        # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
        # x_mod_size = x_mod_size - FLAGS.step_lr * x_grad
        # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)

        # # use cond_pos
        # e_noise = model_pos.forward(x_mod_size, weight_pos, label=LABEL_POS)
        # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
        # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
        # x_mod_size = x_mod_size - FLAGS.step_lr * tf.stop_gradient(x_grad)
        # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)

        x_mod = x_mod_pos
        x_final = x_mod

    if FLAGS.joint_shape:
        loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
            model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
        energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
            model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
            model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
    elif FLAGS.joint_rot:
        loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
            model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
        energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
            model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
            model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
    else:
        loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
            model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
        energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
            model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
            model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)

    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
    coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
    neg_loss = coeff * (-1 * energy_neg) / norm_constant  # importance-weighted negative loss (not used in loss_total)

    loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
    loss_total = loss_ml + tf.reduce_mean(loss_kl) + \
        1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))

    optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
    gvs = optimizer.compute_gradients(loss_total)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    vs = optimizer.variables()
    sess.run(tf.variables_initializer(vs))

    dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size,
                            num_workers=6, drop_last=True, shuffle=True)

    x_off = tf.reduce_mean(tf.square(x_mod - X))

    itr = 0
    saver = tf.train.Saver()
    x_mod = None

    if FLAGS.train:
        replay_buffer = ReplayBuffer(10000)
        for _ in range(1):
            for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
                data_corrupt = data_corrupt.numpy()[:, :, :]
                data = data.numpy()[:, :, :]

                if x_mod is not None:
                    # Occasionally restart chains from past samples in the replay buffer.
                    replay_buffer.add(x_mod)
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.joint_shape:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
                elif FLAGS.joint_rot:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
                else:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}

                _, off_value, e_pos, e_neg, x_mod = sess.run(
                    [train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
                itr += 1

                if itr % 10 == 0:
                    print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(
                        off_value, e_pos.mean(), e_neg.mean(), itr))

                if itr == FLAGS.break_steps:
                    break

        saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))

    saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))

    l = latents

    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
    elif FLAGS.joint_rot:
        mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & \
            (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]

    losses = []
    for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
        x = 0.5 + np.random.randn(*dat.shape)

        if FLAGS.joint_shape:
            feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1],
                         LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
        elif FLAGS.joint_rot:
            feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1),
                         LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
        else:
            feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}

        for i in range(2):
            x = sess.run([x_final], feed_dict=feed_dict)[0]
            feed_dict[X_NOISE] = x

        loss = sess.run([x_off], feed_dict=feed_dict)[0]
        losses.append(loss)

    print("Mean MSE loss of {} ".format(np.mean(losses)))

    data_try = data_gen[:10]
    data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
    latent_scale = latents_gen[:10, 2:3]
    latent_pos = latents_gen[:10, 4:]

    if FLAGS.joint_shape:
        feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32) - 1],
                     LABEL_POS: latent_pos}
    elif FLAGS.joint_rot:
        feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1),
                     LABEL_POS: latent[:10, 4:], X_NOISE: data_init}
    else:
        feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}

    x_output = sess.run([x_final], feed_dict=feed_dict)[0]

    if FLAGS.joint_shape:
        im_name = "size_shape_combine_gentest.png"
    else:
        im_name = "size_scale_combine_gentest.png"

    x_output_wrap = np.ones((10, 66, 66))
    data_try_wrap = np.ones((10, 66, 66))

    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66 * 2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))
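
# For reference, the objective built in gentest above, with E+ the summed
# energies of real data X and E- the energies of the sampled x_mod, is
#     loss_total = mean(E+) - mean(E-) + mean(loss_kl) + mean(E+^2) + mean(E-^2),
# i.e. a contrastive-divergence-style maximum-likelihood term, a term that
# differentiates through the sampler with the model weights stopped
# (loss_kl), and an L2 regularizer that keeps energy magnitudes bounded.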


def conceptcombine(sess, kvs, data, latents, save_exp_dir):
    X_NOISE = kvs['X_NOISE']
    LABEL_SIZE = kvs['LABEL_SIZE']
    LABEL_SHAPE = kvs['LABEL_SHAPE']
    LABEL_POS = kvs['LABEL_POS']
    LABEL_ROT = kvs['LABEL_ROT']
    model_size = kvs['model_size']
    model_shape = kvs['model_shape']
    model_pos = kvs['model_pos']
    model_rot = kvs['model_rot']
    weight_size = kvs['weight_size']
    weight_shape = kvs['weight_shape']
    weight_pos = kvs['weight_pos']
    weight_rot = kvs['weight_rot']

    x_mod = X_NOISE
    for i in range(FLAGS.num_steps):
        if FLAGS.cond_scale:
            e_noise = model_size.forward(x_mod, weight_size, label=LABEL_SIZE)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        if FLAGS.cond_shape:
            e_noise = model_shape.forward(x_mod, weight_shape, label=LABEL_SHAPE)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        if FLAGS.cond_pos:
            e_noise = model_pos.forward(x_mod, weight_pos, label=LABEL_POS)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        if FLAGS.cond_rot:
            e_noise = model_rot.forward(x_mod, weight_rot, label=LABEL_ROT)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        print("Finished constructing loop {}".format(i))

    x_final = x_mod

    data_try = data[:10]
    data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
    label_scale = latents[:10, 2:3]
    label_shape = np.eye(3)[(latents[:10, 1] - 1).astype(np.uint8)]
    label_rot = latents[:10, 3:4]
    label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
    label_pos = latents[:10, 4:]

    feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape,
                 LABEL_POS: label_pos, LABEL_ROT: label_rot}
    x_out = sess.run([x_final], feed_dict)[0]

    im_name = "im"
    if FLAGS.cond_scale:
        im_name += "_condscale"
    if FLAGS.cond_shape:
        im_name += "_condshape"
    if FLAGS.cond_pos:
        im_name += "_condpos"
    if FLAGS.cond_rot:
        im_name += "_condrot"
    im_name += ".png"

    x_out_pad, data_try_pad = np.ones((10, 66, 66)), np.ones((10, 66, 66))
    x_out_pad[:, 1:-1, 1:-1] = x_out
    data_try_pad[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66 * 2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))
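
# Why cycling through the conditional models works: each Langevin step in
# conceptcombine descends a different conditional energy, so over many small
# steps the chain roughly descends the summed energy
#     E_scale + E_shape + E_pos + E_rot,
# which corresponds to sampling from the product of the individual
# conditional distributions p_i(x | label_i) proportional to exp(-E_i).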


def main():
    data = np.load(FLAGS.dsprites_path)['imgs']
    l = latents = np.load(FLAGS.dsprites_path)['latents_values']

    np.random.seed(1)
    idx = np.random.permutation(data.shape[0])
    data = data[idx]
    latents = latents[idx]

    config = tf.ConfigProto()
    sess = tf.Session(config=config)

    # Model 1 will be conditioned on size
    model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
    weight_size = model_size.construct_weights('context_0')

    # Model 2 will be conditioned on shape
    model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
    weight_shape = model_shape.construct_weights('context_1')

    # Model 3 will be conditioned on position
    model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
    weight_pos = model_pos.construct_weights('context_2')

    # Model 4 will be conditioned on rotation
    model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
    weight_rot = model_rot.construct_weights('context_3')

    sess.run(tf.global_variables_initializer())

    # Restore each conditional model from its own checkpoint, remapping the
    # per-model variable scope back to 'context_0', the scope the checkpoints
    # were saved under.
    save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size))
    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0))
    v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list}
    if FLAGS.cond_scale:
        saver = tf.train.Saver(v_map)
        saver.restore(sess, save_path_size)

    save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape))
    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
    v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
    if FLAGS.cond_shape:
        saver = tf.train.Saver(v_map)
        saver.restore(sess, save_path_shape)

    save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos))
    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
    v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
    saver = tf.train.Saver(v_map)
    if FLAGS.cond_pos:
        saver.restore(sess, save_path_pos)

    save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot))
    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3))
    v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list}
    saver = tf.train.Saver(v_map)
    if FLAGS.cond_rot:
        saver.restore(sess, save_path_rot)

    X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
    LABEL_SIZE = tf.placeholder(shape=(None, 1), dtype=tf.float32)
    LABEL_SHAPE = tf.placeholder(shape=(None, 3), dtype=tf.float32)
    LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
    LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)

    x_mod = X_NOISE

    kvs = {}
    kvs['X_NOISE'] = X_NOISE
    kvs['LABEL_SIZE'] = LABEL_SIZE
    kvs['LABEL_SHAPE'] = LABEL_SHAPE
    kvs['LABEL_POS'] = LABEL_POS
    kvs['LABEL_ROT'] = LABEL_ROT
    kvs['model_size'] = model_size
    kvs['model_shape'] = model_shape
    kvs['model_pos'] = model_pos
    kvs['model_rot'] = model_rot
    kvs['weight_size'] = weight_size
    kvs['weight_shape'] = weight_shape
    kvs['weight_pos'] = weight_pos
    kvs['weight_rot'] = weight_rot

    save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape))
    if not osp.exists(save_exp_dir):
        os.makedirs(save_exp_dir)

    if FLAGS.task == 'conceptcombine':
        conceptcombine(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'labeldiscover':
        labeldiscover(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'gentest':
        save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos))
        if not osp.exists(save_exp_dir):
            os.makedirs(save_exp_dir)
        gentest(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'genbaseline':
        save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos))
        if not osp.exists(save_exp_dir):
            os.makedirs(save_exp_dir)

        if FLAGS.plot_curve:
            mse_losses = []
            for frac in [i / 10 for i in range(11)]:
                mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac)
                mse_losses.append(mse_loss)
            np.save("mse_baseline_comb.npy", mse_losses)
        else:
            genbaseline(sess, kvs, data, latents, save_exp_dir)


if __name__ == "__main__":
    main()
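
# Example invocations (the script name here is a placeholder -- adjust it,
# the dataset path, and the checkpoint flags to your local setup):
#   python ebm_combine.py --task=conceptcombine
#   python ebm_combine.py --task=gentest --joint_shape --train
#   python ebm_combine.py --task=genbaseline --plot_curve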