import os
import os.path as osp

import numpy as np
import tensorflow as tf
from custom_adam import AdamOptimizer
# DspritesNetGen (used by genbaseline below) is assumed to be defined in models
# alongside DspritesNet; the original import omitted it.
from models import DspritesNet, DspritesNetGen
from scipy.misc import imsave  # removed in scipy >= 1.2; imageio.imwrite is a drop-in replacement
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from utils import ReplayBuffer

flags.DEFINE_integer("batch_size", 256, "Size of inputs")
flags.DEFINE_integer("data_workers", 4, "Number of dataloader worker processes")
flags.DEFINE_string("logdir", "cachedir", "directory for logging")
flags.DEFINE_string(
    "savedir", "cachedir", "location where log of experiments will be stored"
)
flags.DEFINE_integer(
    "num_filters",
    64,
    "number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
)
flags.DEFINE_float("step_lr", 500, "step size for Langevin dynamics gradient descent")
flags.DEFINE_string(
    "dsprites_path",
    "/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
    "path to the dSprites .npz dataset",
)
flags.DEFINE_bool("cclass", True, "whether to use class-conditional models")
flags.DEFINE_bool("proj_cclass", False, "use for backwards compatibility reasons")
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
flags.DEFINE_bool("plot_curve", False, "Generate a curve of results")
flags.DEFINE_integer("num_steps", 20, "number of gradient steps for Langevin sampling")
flags.DEFINE_string(
    "task",
    "conceptcombine",
    "conceptcombine, labeldiscover, gentest, genbaseline, etc.",
)
flags.DEFINE_bool(
    "joint_shape", False, "generalize jointly over position and shape (instead of size)"
)
flags.DEFINE_bool(
    "joint_rot", False, "generalize jointly over position and rotation (instead of size)"
)

# Conditions on which models to use
flags.DEFINE_bool("cond_pos", True, "whether to condition on position")
flags.DEFINE_bool("cond_rot", True, "whether to condition on rotation")
flags.DEFINE_bool("cond_shape", True, "whether to condition on shape")
flags.DEFINE_bool("cond_scale", True, "whether to condition on scale")

flags.DEFINE_string("exp_size", "dsprites_2018_cond_size", "name of experiments")
flags.DEFINE_string("exp_shape", "dsprites_2018_cond_shape", "name of experiments")
flags.DEFINE_string("exp_pos", "dsprites_2018_cond_pos_cert", "name of experiments")
flags.DEFINE_string("exp_rot", "dsprites_cond_rot_119_00", "name of experiments")
flags.DEFINE_integer("resume_size", 169000, "checkpoint iteration for the size model")
flags.DEFINE_integer("resume_shape", 477000, "checkpoint iteration for the shape model")
flags.DEFINE_integer("resume_pos", 8000, "checkpoint iteration for the position model")
flags.DEFINE_integer("resume_rot", 690000, "checkpoint iteration for the rotation model")
flags.DEFINE_integer("break_steps", 300, "training step at which to stop gentest training")

# Whether to train for gentest
flags.DEFINE_bool(
    "train",
    False,
    "whether to train the models used by the gentest/genbaseline generalization tasks",
)

FLAGS = flags.FLAGS


class DSpritesGen(Dataset):
    def __init__(self, data, latents, frac=0.0):
        l = latents

        if FLAGS.joint_shape:
            mask_size = (
                (l[:, 3] == 30 * np.pi / 39)
                & (l[:, 4] == 16 / 31)
                & (l[:, 5] == 16 / 31)
                & (l[:, 2] == 0.5)
            )
        elif FLAGS.joint_rot:
            mask_size = (
                (l[:, 1] == 1)
                & (l[:, 4] == 16 / 31)
                & (l[:, 5] == 16 / 31)
                & (l[:, 2] == 0.5)
            )
        else:
            mask_size = (
                (l[:, 3] == 30 * np.pi / 39)
                & (l[:, 4] == 16 / 31)
                & (l[:, 5] == 16 / 31)
                & (l[:, 1] == 1)
            )

        mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)

        data_pos = data[mask_pos]
        l_pos = l[mask_pos]

        data_size = data[mask_size]
        l_size = l[mask_size]

        n = data_pos.shape[0] // data_size.shape[0]
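
        # n is the size ratio between the two splits; the position split is
        # tiled by this factor before the two splits are concatenated.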
        data_pos = np.tile(data_pos, (n, 1, 1))
        l_pos = np.tile(l_pos, (n, 1))

        self.data = np.concatenate((data_pos, data_size), axis=0)
        self.label = np.concatenate((l_pos, l_size), axis=0)

        # Optionally mix in a fraction frac of the held-out factor
        # combinations as extra training data.
        mask_neg = (~(mask_size & mask_pos)) & (
            (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39)
        )
        data_add = data[mask_neg]
        l_add = l[mask_neg]

        perm_idx = np.random.permutation(data_add.shape[0])
        select_idx = perm_idx[: int(frac * perm_idx.shape[0])]
        data_add = data_add[select_idx]
        l_add = l_add[select_idx]

        self.data = np.concatenate((self.data, data_add), axis=0)
        self.label = np.concatenate((self.label, l_add), axis=0)

        self.identity = np.eye(3)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        im = self.data[index]
        im_corrupt = 0.5 + 0.5 * np.random.randn(64, 64)

        if FLAGS.joint_shape:
            label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
        elif FLAGS.joint_rot:
            label_size = np.array(
                [np.cos(self.label[index, 3]), np.sin(self.label[index, 3])]
            )
        else:
            label_size = self.label[index, 2:3]

        label_pos = self.label[index, 4:]

        return (im_corrupt, im, label_size, label_pos)


def labeldiscover(sess, kvs, data, latents, save_exp_dir):
    LABEL_SIZE = kvs["LABEL_SIZE"]
    model_size = kvs["model_size"]
    weight_size = kvs["weight_size"]
    x_mod = kvs["X_NOISE"]

    # Recover the size label by noisy gradient descent on the energy of the
    # observed image under the size-conditioned model.
    label_output = LABEL_SIZE
    for i in range(FLAGS.num_steps):
        label_output = label_output + tf.random_normal(
            tf.shape(label_output), mean=0.0, stddev=0.03
        )
        e_noise = model_size.forward(x_mod, weight_size, label=label_output)
        label_grad = tf.gradients(e_noise, [label_output])[0]
        # label_grad = tf.Print(label_grad, [label_grad])
        label_output = label_output - 1.0 * label_grad
        label_output = tf.clip_by_value(label_output, 0.5, 1.0)

    diffs = []
    for i in range(30):
        s = i * FLAGS.batch_size
        d = (i + 1) * FLAGS.batch_size
        data_i = data[s:d]
        latent_i = latents[s:d]

        latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))
        feed_dict = {x_mod: data_i, LABEL_SIZE: latent_init}
        size_pred = sess.run([label_output], feed_dict)[0]

        size_gt = latent_i[:, 2:3]
        diffs.append(np.abs(size_pred - size_gt).mean())

    print("Mean absolute size error: {}".format(np.array(diffs).mean()))


def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
    # tf.reset_default_graph()

    # Feedforward baseline: a generator that decodes a random latent vector
    # plus the concept labels directly into an image, trained with MSE.
    if FLAGS.joint_shape:
        model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=5)
        LABEL = tf.placeholder(shape=(None, 5), dtype=tf.float32)
    else:
        model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
        LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)

    weights_baseline = model_baseline.construct_weights(
        "context_baseline_{}".format(frac)
    )

    X_feed = tf.placeholder(shape=(None, 2 * FLAGS.num_filters), dtype=tf.float32)
    X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
    loss_sq = tf.reduce_mean(tf.square(X_out - X_label))

    optimizer = AdamOptimizer(1e-3)
    gvs = optimizer.compute_gradients(loss_sq)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    dataloader = DataLoader(
        DSpritesGen(data, latents, frac=frac),
        batch_size=FLAGS.batch_size,
        num_workers=6,
        drop_last=True,
        shuffle=True,
    )

    datafull = data
    itr = 0
    saver = tf.train.Saver()

    optimizer.variables()
    sess.run(tf.global_variables_initializer())

    if FLAGS.train:
        for _ in range(5):
            for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
                data_corrupt = data_corrupt.numpy()
                label_size, label_pos = label_size.numpy(), label_pos.numpy()

                data_corrupt = np.random.randn(
                    data_corrupt.shape[0], 2 * FLAGS.num_filters
                )
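                # data_corrupt is replaced by a Gaussian latent: the baseline
                # decodes noise plus labels rather than refining a corrupted image.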
                label_comb = np.concatenate([label_size, label_pos], axis=1)

                feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}

                output = [loss_sq, train_op]
                loss, _ = sess.run(output, feed_dict=feed_dict)

                itr += 1

        saver.save(sess, osp.join(save_exp_dir, "model_genbaseline"))

    saver.restore(sess, osp.join(save_exp_dir, "model_genbaseline"))

    # Evaluate on the held-out combinations of factors.
    l = latents

    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) & (l[:, 1] == 1) & (
            ~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
        )

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]

    losses = []
    for dat, latent in zip(
        np.array_split(data_gen, 10), np.array_split(latents_gen, 10)
    ):
        data_init = np.random.randn(dat.shape[0], 2 * FLAGS.num_filters)

        if FLAGS.joint_shape:
            latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
            latent_pos = latent[:, 4:6]
            latent = np.concatenate([latent_size, latent_pos], axis=1)
            feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
        else:
            feed_dict = {X_feed: data_init, LABEL: latent[:, [2, 4, 5]], X_label: dat}

        loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
        # print(loss)
        losses.append(loss)

    print(
        "Overall generalization MSE of {} at fraction {}".format(
            np.mean(losses), frac
        )
    )

    data_try = data_gen[:10]
    data_init = np.random.randn(10, 2 * FLAGS.num_filters)

    if FLAGS.joint_shape:
        latent_scale = np.eye(3)[latents_gen[:10, 1].astype(np.int32) - 1]
        latent_pos = latents_gen[:10, 4:]
    else:
        latent_scale = latents_gen[:10, 2:3]
        latent_pos = latents_gen[:10, 4:]

    latent_tot = np.concatenate([latent_scale, latent_pos], axis=1)

    feed_dict = {X_feed: data_init, LABEL: latent_tot}
    x_output = sess.run([X_out], feed_dict=feed_dict)[0]
    x_output = np.clip(x_output, 0, 1)

    im_name = "size_scale_combine_genbaseline.png"

    # Pad each 64x64 image with a 1-pixel white border and lay generated and
    # ground-truth images side by side.
    x_output_wrap = np.ones((10, 66, 66))
    data_try_wrap = np.ones((10, 66, 66))

    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
        -1, 66 * 2
    )
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))

    return np.mean(losses)


def gentest(sess, kvs, data, latents, save_exp_dir):
    X_NOISE = kvs["X_NOISE"]
    LABEL_SIZE = kvs["LABEL_SIZE"]
    LABEL_SHAPE = kvs["LABEL_SHAPE"]
    LABEL_POS = kvs["LABEL_POS"]
    LABEL_ROT = kvs["LABEL_ROT"]
    model_size = kvs["model_size"]
    model_shape = kvs["model_shape"]
    model_pos = kvs["model_pos"]
    model_rot = kvs["model_rot"]
    weight_size = kvs["weight_size"]
    weight_shape = kvs["weight_shape"]
    weight_pos = kvs["weight_pos"]
    weight_rot = kvs["weight_rot"]
    X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    datafull = data

    # Test combination of generalization where we use slices of both training
    x_final = X_NOISE
    x_mod_pos = X_NOISE

    for i in range(FLAGS.num_steps):
        # use cond_pos
        x_mod_pos = x_mod_pos + tf.random_normal(
            tf.shape(x_mod_pos), mean=0.0, stddev=0.005
        )
        e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)
        # energies.append(e_noise)
        x_grad = tf.gradients(e_noise, [x_final])[0]

        x_mod_pos = x_mod_pos + tf.random_normal(
            tf.shape(x_mod_pos), mean=0.0, stddev=0.005
        )
        x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
        x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)

        if FLAGS.joint_shape:
            # use cond_shape
            e_noise = model_shape.forward(x_mod_pos, weight_shape, label=LABEL_SHAPE)
        elif FLAGS.joint_rot:
            e_noise = model_rot.forward(x_mod_pos, weight_rot, label=LABEL_ROT)
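        # Otherwise take the second gradient step with the size-conditioned model.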
        else:
            # use cond_size
            e_noise = model_size.forward(x_mod_pos, weight_size, label=LABEL_SIZE)

        # energies.append(e_noise)
        # energy_stack = tf.concat(energies, axis=1)
        # energy_stack = tf.reduce_logsumexp(-1 * energy_stack, axis=1)
        # energy_stack = tf.reduce_sum(energy_stack, axis=1)

        x_grad = tf.gradients(e_noise, [x_mod_pos])[0]
        x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
        x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)

        # for x_mod_size
        # use cond_size
        # e_noise = model_size.forward(x_mod_size, weight_size, label=LABEL_SIZE)
        # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
        # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
        # x_mod_size = x_mod_size - FLAGS.step_lr * x_grad
        # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)

        # # use cond_pos
        # e_noise = model_pos.forward(x_mod_size, weight_pos, label=LABEL_POS)
        # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
        # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
        # x_mod_size = x_mod_size - FLAGS.step_lr * tf.stop_gradient(x_grad)
        # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)

        x_mod = x_mod_pos
        x_final = x_mod

    # Contrastive objective: push down the energy of real images (energy_pos)
    # and push up the energy of sampled images (energy_neg).
    if FLAGS.joint_shape:
        loss_kl = model_shape.forward(
            x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True
        ) + model_pos.forward(
            x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
        )
        energy_pos = model_shape.forward(
            X, weight_shape, reuse=True, label=LABEL_SHAPE
        ) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_shape.forward(
            tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE
        ) + model_pos.forward(
            tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
        )
    elif FLAGS.joint_rot:
        loss_kl = model_rot.forward(
            x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True
        ) + model_pos.forward(
            x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
        )
        energy_pos = model_rot.forward(
            X, weight_rot, reuse=True, label=LABEL_ROT
        ) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_rot.forward(
            tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT
        ) + model_pos.forward(
            tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
        )
    else:
        loss_kl = model_size.forward(
            x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True
        ) + model_pos.forward(
            x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
        )
        energy_pos = model_size.forward(
            X, weight_size, reuse=True, label=LABEL_SIZE
        ) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_size.forward(
            tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE
        ) + model_pos.forward(
            tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
        )

    energy_neg_reduced = energy_neg - tf.reduce_min(energy_neg)
    coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
    # Unused importance weighting of the negative phase:
    # coeff * (-1 * energy_neg) / norm_constant

    loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
    loss_total = (
        loss_ml
        + tf.reduce_mean(loss_kl)
        + 1
        * (
            tf.reduce_mean(tf.square(energy_pos))
            + tf.reduce_mean(tf.square(energy_neg))
        )
    )

    optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
    gvs = optimizer.compute_gradients(loss_total)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    vs = optimizer.variables()
    sess.run(tf.variables_initializer(vs))

    dataloader = DataLoader(
        DSpritesGen(data, latents),
        batch_size=FLAGS.batch_size,
        num_workers=6,
        drop_last=True,
        shuffle=True,
    )
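
    # Reconstruction error between the Langevin samples and the ground truth.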
    x_off = tf.reduce_mean(tf.square(x_mod - X))

    itr = 0
    saver = tf.train.Saver()
    x_mod = None  # the Python name is reused below for numpy sample batches

    if FLAGS.train:
        replay_buffer = ReplayBuffer(10000)
        for _ in range(1):
            for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
                data_corrupt = data_corrupt.numpy()[:, :, :]
                data = data.numpy()[:, :, :]

                if x_mod is not None:
                    replay_buffer.add(x_mod)
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_mask = np.random.uniform(0, 1, (FLAGS.batch_size,)) > 0.95
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.joint_shape:
                    feed_dict = {
                        X_NOISE: data_corrupt,
                        X: data,
                        LABEL_SHAPE: label_size,
                        LABEL_POS: label_pos,
                    }
                elif FLAGS.joint_rot:
                    feed_dict = {
                        X_NOISE: data_corrupt,
                        X: data,
                        LABEL_ROT: label_size,
                        LABEL_POS: label_pos,
                    }
                else:
                    feed_dict = {
                        X_NOISE: data_corrupt,
                        X: data,
                        LABEL_SIZE: label_size,
                        LABEL_POS: label_pos,
                    }

                _, off_value, e_pos, e_neg, x_mod = sess.run(
                    [train_op, x_off, energy_pos, energy_neg, x_final],
                    feed_dict=feed_dict,
                )
                itr += 1

                if itr % 10 == 0:
                    print(
                        "x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(
                            off_value, e_pos.mean(), e_neg.mean(), itr
                        )
                    )

                if itr == FLAGS.break_steps:
                    break

        saver.save(sess, osp.join(save_exp_dir, "model_gentest"))

    saver.restore(sess, osp.join(save_exp_dir, "model_gentest"))

    # Evaluate on the held-out combinations of factors.
    l = latents

    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
    elif FLAGS.joint_rot:
        mask_gen = (l[:, 1] == 1) & (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) & (l[:, 1] == 1) & (
            ~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
        )

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]

    losses = []
    for dat, latent in zip(
        np.array_split(data_gen, 120), np.array_split(latents_gen, 120)
    ):
        x = 0.5 + np.random.randn(*dat.shape)

        if FLAGS.joint_shape:
            feed_dict = {
                LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1],
                LABEL_POS: latent[:, 4:],
                X_NOISE: x,
                X: dat,
            }
        elif FLAGS.joint_rot:
            feed_dict = {
                LABEL_ROT: np.concatenate(
                    [np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1
                ),
                LABEL_POS: latent[:, 4:],
                X_NOISE: x,
                X: dat,
            }
        else:
            feed_dict = {
                LABEL_SIZE: latent[:, 2:3],
                LABEL_POS: latent[:, 4:],
                X_NOISE: x,
                X: dat,
            }

        # Run two rounds of Langevin sampling, re-feeding the samples.
        for i in range(2):
            x = sess.run([x_final], feed_dict=feed_dict)[0]
            feed_dict[X_NOISE] = x

        loss = sess.run([x_off], feed_dict=feed_dict)[0]
        losses.append(loss)

    print("Mean MSE loss of {}".format(np.mean(losses)))

    data_try = data_gen[:10]
    data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
    latent_scale = latents_gen[:10, 2:3]
    latent_pos = latents_gen[:10, 4:]

    if FLAGS.joint_shape:
        feed_dict = {
            X_NOISE: data_init,
            LABEL_SHAPE: np.eye(3)[latents_gen[:10, 1].astype(np.int32) - 1],
            LABEL_POS: latent_pos,
        }
    elif FLAGS.joint_rot:
        feed_dict = {
            LABEL_ROT: np.concatenate(
                [np.cos(latents_gen[:10, 3:4]), np.sin(latents_gen[:10, 3:4])], axis=1
            ),
            LABEL_POS: latents_gen[:10, 4:],
            X_NOISE: data_init,
        }
    else:
        feed_dict = {
            X_NOISE: data_init,
            LABEL_SIZE: latent_scale,
            LABEL_POS: latent_pos,
        }

    x_output = sess.run([x_final], feed_dict=feed_dict)[0]

    if FLAGS.joint_shape:
        im_name = "size_shape_combine_gentest.png"
    else:
        im_name = "size_scale_combine_gentest.png"

    x_output_wrap = np.ones((10, 66, 66))
    data_try_wrap = np.ones((10, 66, 66))

    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
        -1, 66 * 2
    )
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))


def conceptcombine(sess, kvs, data, latents, save_exp_dir):
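    """Compose the separately trained conditional EBMs: at every Langevin step,
    each enabled model in turn takes a noisy gradient step on the sample."""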
    X_NOISE = kvs["X_NOISE"]
    LABEL_SIZE = kvs["LABEL_SIZE"]
    LABEL_SHAPE = kvs["LABEL_SHAPE"]
    LABEL_POS = kvs["LABEL_POS"]
    LABEL_ROT = kvs["LABEL_ROT"]
    model_size = kvs["model_size"]
    model_shape = kvs["model_shape"]
    model_pos = kvs["model_pos"]
    model_rot = kvs["model_rot"]
    weight_size = kvs["weight_size"]
    weight_shape = kvs["weight_shape"]
    weight_pos = kvs["weight_pos"]
    weight_rot = kvs["weight_rot"]

    x_mod = X_NOISE
    for i in range(FLAGS.num_steps):
        if FLAGS.cond_scale:
            e_noise = model_size.forward(x_mod, weight_size, label=LABEL_SIZE)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        if FLAGS.cond_shape:
            e_noise = model_shape.forward(x_mod, weight_shape, label=LABEL_SHAPE)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        if FLAGS.cond_pos:
            e_noise = model_pos.forward(x_mod, weight_pos, label=LABEL_POS)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        if FLAGS.cond_rot:
            e_noise = model_rot.forward(x_mod, weight_rot, label=LABEL_ROT)
            x_grad = tf.gradients(e_noise, [x_mod])[0]
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
            x_mod = x_mod - FLAGS.step_lr * x_grad
            x_mod = tf.clip_by_value(x_mod, 0, 1)

        print("Finished constructing loop {}".format(i))

    x_final = x_mod

    data_try = data[:10]
    data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
    label_scale = latents[:10, 2:3]
    label_shape = np.eye(3)[(latents[:10, 1] - 1).astype(np.uint8)]
    label_rot = latents[:10, 3:4]
    label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
    label_pos = latents[:10, 4:]

    feed_dict = {
        X_NOISE: data_init,
        LABEL_SIZE: label_scale,
        LABEL_SHAPE: label_shape,
        LABEL_POS: label_pos,
        LABEL_ROT: label_rot,
    }
    x_out = sess.run([x_final], feed_dict)[0]

    im_name = "im"

    if FLAGS.cond_scale:
        im_name += "_condscale"
    if FLAGS.cond_shape:
        im_name += "_condshape"
    if FLAGS.cond_pos:
        im_name += "_condpos"
    if FLAGS.cond_rot:
        im_name += "_condrot"

    im_name += ".png"

    x_out_pad, data_try_pad = np.ones((10, 66, 66)), np.ones((10, 66, 66))
    x_out_pad[:, 1:-1, 1:-1] = x_out
    data_try_pad[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66 * 2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))


def main():
    dataset = np.load(FLAGS.dsprites_path)
    data = dataset["imgs"]
    latents = dataset["latents_values"]

    np.random.seed(1)
    idx = np.random.permutation(data.shape[0])
    data = data[idx]
    latents = latents[idx]

    config = tf.ConfigProto()
    sess = tf.Session(config=config)

    # Model 1 will be conditioned on size
    model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
    weight_size = model_size.construct_weights("context_0")

    # Model 2 will be conditioned on shape
    model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
    weight_shape = model_shape.construct_weights("context_1")

    # Model 3 will be conditioned on position
    model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
    weight_pos = model_pos.construct_weights("context_2")

    # Model 4 will be conditioned on rotation
    model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
    weight_rot = model_rot.construct_weights("context_3")
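
    # All variables are initialized first, then each model's weights are
    # restored from its own checkpoint; the checkpoints were saved under the
    # scope "context_0", so variable names are remapped before restoring.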
    sess.run(tf.global_variables_initializer())

    save_path_size = osp.join(
        FLAGS.logdir, FLAGS.exp_size, "model_{}".format(FLAGS.resume_size)
    )

    v_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(0)
    )
    v_map = {
        (v.name.replace("context_{}".format(0), "context_0")[:-2]): v for v in v_list
    }

    if FLAGS.cond_scale:
        saver = tf.train.Saver(v_map)
        saver.restore(sess, save_path_size)

    save_path_shape = osp.join(
        FLAGS.logdir, FLAGS.exp_shape, "model_{}".format(FLAGS.resume_shape)
    )

    v_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(1)
    )
    v_map = {
        (v.name.replace("context_{}".format(1), "context_0")[:-2]): v for v in v_list
    }

    if FLAGS.cond_shape:
        saver = tf.train.Saver(v_map)
        saver.restore(sess, save_path_shape)

    save_path_pos = osp.join(
        FLAGS.logdir, FLAGS.exp_pos, "model_{}".format(FLAGS.resume_pos)
    )

    v_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(2)
    )
    v_map = {
        (v.name.replace("context_{}".format(2), "context_0")[:-2]): v for v in v_list
    }
    saver = tf.train.Saver(v_map)

    if FLAGS.cond_pos:
        saver.restore(sess, save_path_pos)

    save_path_rot = osp.join(
        FLAGS.logdir, FLAGS.exp_rot, "model_{}".format(FLAGS.resume_rot)
    )

    v_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(3)
    )
    v_map = {
        (v.name.replace("context_{}".format(3), "context_0")[:-2]): v for v in v_list
    }
    saver = tf.train.Saver(v_map)

    if FLAGS.cond_rot:
        saver.restore(sess, save_path_rot)

    X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
    LABEL_SIZE = tf.placeholder(shape=(None, 1), dtype=tf.float32)
    LABEL_SHAPE = tf.placeholder(shape=(None, 3), dtype=tf.float32)
    LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
    LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)

    kvs = {}
    kvs["X_NOISE"] = X_NOISE
    kvs["LABEL_SIZE"] = LABEL_SIZE
    kvs["LABEL_SHAPE"] = LABEL_SHAPE
    kvs["LABEL_POS"] = LABEL_POS
    kvs["LABEL_ROT"] = LABEL_ROT
    kvs["model_size"] = model_size
    kvs["model_shape"] = model_shape
    kvs["model_pos"] = model_pos
    kvs["model_rot"] = model_rot
    kvs["weight_size"] = weight_size
    kvs["weight_shape"] = weight_shape
    kvs["weight_pos"] = weight_pos
    kvs["weight_rot"] = weight_rot

    save_exp_dir = osp.join(
        FLAGS.savedir, "{}_{}_joint".format(FLAGS.exp_size, FLAGS.exp_shape)
    )
    if not osp.exists(save_exp_dir):
        os.makedirs(save_exp_dir)

    if FLAGS.task == "conceptcombine":
        conceptcombine(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == "labeldiscover":
        labeldiscover(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == "gentest":
        save_exp_dir = osp.join(
            FLAGS.savedir, "{}_{}_gen".format(FLAGS.exp_size, FLAGS.exp_pos)
        )
        if not osp.exists(save_exp_dir):
            os.makedirs(save_exp_dir)
        gentest(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == "genbaseline":
        save_exp_dir = osp.join(
            FLAGS.savedir, "{}_{}_gen_baseline".format(FLAGS.exp_size, FLAGS.exp_pos)
        )
        if not osp.exists(save_exp_dir):
            os.makedirs(save_exp_dir)

        if FLAGS.plot_curve:
            mse_losses = []
            for frac in [i / 10 for i in range(11)]:
                mse_loss = genbaseline(
                    sess, kvs, data, latents, save_exp_dir, frac=frac
                )
                mse_losses.append(mse_loss)
            np.save("mse_baseline_comb.npy", mse_losses)
        else:
            genbaseline(sess, kvs, data, latents, save_exp_dir)


if __name__ == "__main__":
    main()
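

# Example invocations (hypothetical; assumes this file is saved as
# ebm_combine.py and that the checkpoints named by the exp_*/resume_* flags
# exist under FLAGS.logdir):
#   python ebm_combine.py --task=conceptcombine
#   python ebm_combine.py --task=gentest --train
#   python ebm_combine.py --task=genbaseline --plot_curve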