# Mirror of https://github.com/autistic-symposium/ml-ai-agents-py.git
# (synced 2025-08-16 02:00:32 -04:00; 878 lines, 30 KiB, Python)
import os
|
|
import os.path as osp
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from custom_adam import AdamOptimizer
|
|
from models import DspritesNet
|
|
from scipy.misc import imsave
|
|
from tensorflow.python.platform import flags
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from tqdm import tqdm
|
|
from utils import ReplayBuffer
|
|
|
|
# Command-line configuration for the dSprites concept-combination experiments.
flags.DEFINE_integer("batch_size", 256, "Number of images per batch")
# NOTE(review): not read in this file; the DataLoaders below hard-code num_workers=6.
flags.DEFINE_integer("data_workers", 4, "Number of workers for data loading")
flags.DEFINE_string(
    "logdir", "cachedir", "directory containing the pretrained concept model checkpoints"
)
flags.DEFINE_string(
    "savedir", "cachedir", "location where log of experiments will be stored"
)
flags.DEFINE_integer("num_filters", 64, "number of filters for conv nets")
flags.DEFINE_float(
    "step_lr", 500, "step size of each gradient step taken on the image while sampling"
)
flags.DEFINE_string(
    "dsprites_path",
    "/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
    "path to the dSprites .npz archive",
)
# NOTE(review): the next five flags are not read in this file; presumably they
# are consumed by the model definitions -- confirm before changing them.
flags.DEFINE_bool("cclass", True, "not cclass")
flags.DEFINE_bool("proj_cclass", False, "use for backwards compatibility reasons")
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
flags.DEFINE_bool(
    "plot_curve",
    False,
    "for genbaseline, sweep the fraction of added held-out data from 0 to 1 "
    "and save the resulting MSE curve",
)
flags.DEFINE_integer(
    "num_steps",
    20,
    "number of gradient steps used when sampling images or optimizing the label",
)
flags.DEFINE_string(
    "task",
    "conceptcombine",
    "conceptcombine, labeldiscover, gentest, genbaseline, etc.",
)
flags.DEFINE_bool(
    "joint_shape",
    False,
    "combine position with shape (instead of size) in gentest/genbaseline",
)
flags.DEFINE_bool(
    "joint_rot",
    False,
    "combine position with rotation (instead of size) in gentest",
)

# Conditions on which models to use
flags.DEFINE_bool("cond_pos", True, "whether to condition on position")
flags.DEFINE_bool("cond_rot", True, "whether to condition on rotation")
flags.DEFINE_bool("cond_shape", True, "whether to condition on shape")
flags.DEFINE_bool("cond_scale", True, "whether to condition on scale")

# Experiment names and checkpoint iterations of the four pretrained models;
# checkpoints are read from <logdir>/<exp_*>/model_<resume_*>.
flags.DEFINE_string(
    "exp_size", "dsprites_2018_cond_size", "experiment name of the size-conditioned model"
)
flags.DEFINE_string(
    "exp_shape", "dsprites_2018_cond_shape", "experiment name of the shape-conditioned model"
)
flags.DEFINE_string(
    "exp_pos",
    "dsprites_2018_cond_pos_cert",
    "experiment name of the position-conditioned model",
)
flags.DEFINE_string(
    "exp_rot",
    "dsprites_cond_rot_119_00",
    "experiment name of the rotation-conditioned model",
)
flags.DEFINE_integer(
    "resume_size", 169000, "checkpoint iteration of the size model to resume"
)
flags.DEFINE_integer(
    "resume_shape", 477000, "checkpoint iteration of the shape model to resume"
)
flags.DEFINE_integer(
    "resume_pos", 8000, "checkpoint iteration of the position model to resume"
)
flags.DEFINE_integer(
    "resume_rot", 690000, "checkpoint iteration of the rotation model to resume"
)
flags.DEFINE_integer(
    "break_steps", 300, "stop gentest training after this many iterations"
)

# Whether to train for gentest
flags.DEFINE_bool(
    "train",
    False,
    "whether to train the gentest/genbaseline models "
    "(otherwise restore them from the experiment save directory)",
)

FLAGS = flags.FLAGS
|
|
|
|
|
|
class DSpritesGen(Dataset):
    """Training set for the dSprites compositional-generalization experiments.

    The set is the union of two slices of dSprites:

    * a *position slice*: every (x, y) position for shape 1 at rotation
      ``30 * pi / 39`` and scale 0.5;
    * a *second-concept slice* at position ``(16/31, 16/31)`` that varies one
      other latent, chosen by the mode flags:

      - ``FLAGS.joint_shape``: every shape (rotation and scale fixed);
      - ``FLAGS.joint_rot``: every rotation (shape 1, scale 0.5);
      - otherwise: every scale (shape 1, rotation fixed).

    The position slice is tiled so the two slices contribute roughly equally.
    Optionally a random fraction ``frac`` of the remaining shape-1 /
    fixed-rotation images (those in neither slice) is added.
    """

    def __init__(self, data, latents, frac=0.0):
        """Build the dataset.

        Args:
            data: image array indexed by example along axis 0 (assumed
                ``(N, 64, 64)`` as for dSprites -- TODO confirm for other inputs).
            latents: latent array aligned with ``data``. Columns 1, 2, 3 and 4:
                are read as shape, scale, rotation and (x, y) position.
            frac: fraction in [0, 1] of the held-out images to add.

        Raises:
            ValueError: if no image matches the second-concept slice.
        """
        l = latents
        # Individual latent constraints used to carve out the slices.
        rot_fixed = l[:, 3] == 30 * np.pi / 39
        pos_fixed = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)
        shape_fixed = l[:, 1] == 1
        scale_fixed = l[:, 2] == 0.5

        # Second-concept slice: one latent varies at a fixed position.
        if FLAGS.joint_shape:
            mask_size = rot_fixed & pos_fixed & scale_fixed  # all shapes
        elif FLAGS.joint_rot:
            mask_size = shape_fixed & pos_fixed & scale_fixed  # all rotations
        else:
            mask_size = rot_fixed & pos_fixed & shape_fixed  # all scales

        # Position slice: every position for one shape / rotation / scale.
        mask_pos = shape_fixed & rot_fixed & scale_fixed

        data_pos, l_pos = data[mask_pos], l[mask_pos]
        data_size, l_size = data[mask_size], l[mask_size]
        if data_size.shape[0] == 0:
            raise ValueError(
                "DSpritesGen: no images match the second-concept slice"
            )

        # Tile the position slice so both slices are roughly balanced; never
        # tile zero times, which would silently drop the position slice.
        n = max(1, data_pos.shape[0] // data_size.shape[0])
        data_pos = np.tile(data_pos, (n, 1, 1))
        l_pos = np.tile(l_pos, (n, 1))

        self.data = np.concatenate((data_pos, data_size), axis=0)
        self.label = np.concatenate((l_pos, l_size), axis=0)

        # Held-out pool: shape-1, fixed-rotation images lying in NEITHER slice.
        # Excluding the union (not merely the intersection) of the two slices
        # keeps the added images genuinely unseen. In the default mode this is
        # the same region that gentest/genbaseline evaluate on.
        mask_neg = ~(mask_size | mask_pos) & shape_fixed & rot_fixed
        data_add = data[mask_neg]
        l_add = l[mask_neg]

        perm_idx = np.random.permutation(data_add.shape[0])
        select_idx = perm_idx[: int(frac * perm_idx.shape[0])]

        self.data = np.concatenate((self.data, data_add[select_idx]), axis=0)
        self.label = np.concatenate((self.label, l_add[select_idx]), axis=0)

        # One-hot lookup table for the three shape classes (shape latent 1..3).
        self.identity = np.eye(3)

    def __len__(self):
        """Return the number of training examples."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(im_corrupt, im, label_size, label_pos)`` for ``index``.

        ``im_corrupt`` is Gaussian noise (mean 0.5, std 0.5) with the image's
        shape. ``label_size`` holds the second concept:

        * a one-hot shape vector (``FLAGS.joint_shape``);
        * ``(cos, sin)`` of the rotation (``FLAGS.joint_rot``);
        * otherwise the scale as a length-1 array.

        ``label_pos`` holds latent columns 4: (the position).
        """
        im = self.data[index]
        label = self.label[index]
        im_corrupt = 0.5 + 0.5 * np.random.randn(*im.shape)

        if FLAGS.joint_shape:
            label_size = self.identity[label[1].astype(np.int32) - 1]
        elif FLAGS.joint_rot:
            label_size = np.array([np.cos(label[3]), np.sin(label[3])])
        else:
            label_size = label[2:3]

        label_pos = label[4:]

        return (im_corrupt, im, label_size, label_pos)
|
|
|
|
|
|
def labeldiscover(sess, kvs, data, latents, save_exp_dir, num_batches=30):
    """Infer each image's size latent from the size model's energy.

    Starting from a uniform random guess in [0.5, 1], the label is refined for
    ``FLAGS.num_steps`` steps. Each step adds Gaussian noise (std 0.03), takes a
    gradient step on the size model's energy with respect to the label, and
    clips the label to [0.5, 1.0]. The mean absolute error against the true
    size latent (column 2) is printed and returned.

    Args:
        sess: TensorFlow session holding the restored size model.
        kvs: dict of graph handles built by ``main`` (``X_NOISE``,
            ``LABEL_SIZE``, ``model_size``, ``weight_size`` are used).
        data: image array; consecutive batches of ``FLAGS.batch_size`` images
            are evaluated.
        latents: latent array aligned with ``data``.
        save_exp_dir: unused; kept for a uniform task signature.
        num_batches: maximum number of batches to evaluate (default 30).

    Returns:
        The mean absolute size error, or NaN if no batch was evaluated.
    """
    LABEL_SIZE = kvs["LABEL_SIZE"]
    model_size = kvs["model_size"]
    weight_size = kvs["weight_size"]
    x_mod = kvs["X_NOISE"]

    # Unrolled label optimisation graph.
    label_output = LABEL_SIZE
    for _ in range(FLAGS.num_steps):
        label_output = label_output + tf.random_normal(
            tf.shape(label_output), mean=0.0, stddev=0.03
        )
        e_noise = model_size.forward(x_mod, weight_size, label=label_output)
        label_grad = tf.gradients(e_noise, [label_output])[0]
        label_output = label_output - 1.0 * label_grad
        label_output = tf.clip_by_value(label_output, 0.5, 1.0)

    diffs = []
    for i in range(num_batches):
        start = i * FLAGS.batch_size
        end = start + FLAGS.batch_size
        data_i = data[start:end]
        if data_i.shape[0] == 0:
            break  # ran past the end of the data
        latent_i = latents[start:end]
        # Size the initial guess to the actual batch so a short final batch
        # does not mismatch the image batch.
        latent_init = np.random.uniform(0.5, 1, (data_i.shape[0], 1))

        feed_dict = {x_mod: data_i, LABEL_SIZE: latent_init}
        size_pred = sess.run([label_output], feed_dict)[0]
        size_gt = latent_i[:, 2:3]

        diffs.append(np.abs(size_pred - size_gt).mean())

    mean_diff = float(np.mean(diffs)) if diffs else float("nan")
    print(mean_diff)
    return mean_diff
|
|
|
|
|
|
def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
    """Train/evaluate a feed-forward generator baseline for compositional generalization.

    A ``DspritesNetGen`` maps a Gaussian code of ``2 * FLAGS.num_filters``
    values plus a label to an image. The label is:

    * ``FLAGS.joint_shape``: one-hot shape + (x, y) position, 5 values;
    * otherwise: scale + (x, y) position, 3 values.

    When ``FLAGS.train`` is set, the generator is trained with squared error
    for 5 epochs on ``DSpritesGen(data, latents, frac)`` and saved as
    ``model_genbaseline`` in ``save_exp_dir``. The checkpoint is then restored
    and evaluated on the held-out latent combinations. The mean squared error
    is printed, and a grid of generated (left) vs. real (right) images is
    written to ``save_exp_dir``.

    NOTE(review): this runs ``tf.global_variables_initializer()``, which also
    resets any concept models previously restored into ``sess``.

    Args:
        sess: TensorFlow session.
        kvs: unused; kept for a uniform task signature.
        data: image array.
        latents: latent array aligned with ``data``.
        save_exp_dir: directory for the checkpoint and the output image.
        frac: fraction of held-out images added to the training set.

    Returns:
        Mean squared error on the held-out combinations.
    """
    # ``DspritesNetGen`` is not imported at module level (only ``DspritesNet``
    # is), so the bare name raised NameError. It is imported here so other
    # tasks keep working even if ``models`` does not provide it.
    from models import DspritesNetGen

    def _baseline_labels(latent):
        """Convert raw latent rows to the generator's label vector."""
        if FLAGS.joint_shape:
            shape_onehot = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
            return np.concatenate([shape_onehot, latent[:, 4:6]], axis=1)
        return latent[:, [2, 4, 5]]

    label_dim = 5 if FLAGS.joint_shape else 3
    model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=label_dim)
    LABEL = tf.placeholder(shape=(None, label_dim), dtype=tf.float32)

    weights_baseline = model_baseline.construct_weights(
        "context_baseline_{}".format(frac)
    )

    X_feed = tf.placeholder(shape=(None, 2 * FLAGS.num_filters), dtype=tf.float32)
    X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
    loss_sq = tf.reduce_mean(tf.square(X_out - X_label))

    optimizer = AdamOptimizer(1e-3)
    gvs = optimizer.compute_gradients(loss_sq)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    dataloader = DataLoader(
        DSpritesGen(data, latents, frac=frac),
        batch_size=FLAGS.batch_size,
        num_workers=6,
        drop_last=True,
        shuffle=True,
    )

    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())

    if FLAGS.train:
        for _ in range(5):
            # The dataset's noise image is not used: the generator's input is
            # a fresh Gaussian code.
            for _, data_batch, label_size, label_pos in tqdm(dataloader):
                label_comb = np.concatenate(
                    [label_size.numpy(), label_pos.numpy()], axis=1
                )
                code = np.random.randn(data_batch.shape[0], 2 * FLAGS.num_filters)
                feed_dict = {
                    X_feed: code,
                    X_label: data_batch.numpy(),
                    LABEL: label_comb,
                }
                sess.run([loss_sq, train_op], feed_dict=feed_dict)

        saver.save(sess, osp.join(save_exp_dir, "model_genbaseline"))

    saver.restore(sess, osp.join(save_exp_dir, "model_genbaseline"))

    # Held-out latent combinations.
    l = latents
    rot_fixed = l[:, 3] == 30 * np.pi / 39
    if FLAGS.joint_shape:
        mask_gen = rot_fixed & (l[:, 2] == 0.5)
    else:
        pos_fixed = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)
        mask_gen = rot_fixed & (l[:, 1] == 1) & ~((l[:, 2] == 0.5) | pos_fixed)

    data_gen = data[mask_gen]
    latents_gen = latents[mask_gen]

    losses = []
    for dat, latent in zip(
        np.array_split(data_gen, 10), np.array_split(latents_gen, 10)
    ):
        data_init = np.random.randn(dat.shape[0], 2 * FLAGS.num_filters)
        feed_dict = {X_feed: data_init, LABEL: _baseline_labels(latent), X_label: dat}
        losses.append(sess.run([loss_sq], feed_dict=feed_dict)[0])

    print(
        "Overall MSE for generalization of {} for fraction of {}".format(
            np.mean(losses), frac
        )
    )

    # Visualise up to 10 held-out images next to their generated versions,
    # using the labels of exactly those images.
    n_show = min(10, data_gen.shape[0])
    data_try = data_gen[:n_show]
    data_init = np.random.randn(n_show, 2 * FLAGS.num_filters)
    feed_dict = {X_feed: data_init, LABEL: _baseline_labels(latents_gen[:n_show])}
    x_output = sess.run([X_out], feed_dict=feed_dict)[0]
    x_output = np.clip(x_output, 0, 1)

    im_name = "size_scale_combine_genbaseline.png"

    # 1-pixel white border around each 64x64 image.
    x_output_wrap = np.ones((n_show, 66, 66))
    data_try_wrap = np.ones((n_show, 66, 66))
    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
        -1, 66 * 2
    )
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))

    return np.mean(losses)
|
|
|
|
|
|
def gentest(sess, kvs, data, latents, save_exp_dir):
    """Jointly fine-tune two pretrained EBMs and test compositional generalization.

    The position model is composed with a second concept model:

    * ``FLAGS.joint_shape``: shape;
    * ``FLAGS.joint_rot``: rotation;
    * otherwise: size.

    Images are sampled from ``X_NOISE`` by alternating gradient steps on the
    two energies for ``FLAGS.num_steps`` iterations.

    When ``FLAGS.train`` is set, the pair is trained on ``DSpritesGen`` and
    saved as ``model_gentest`` in ``save_exp_dir``. Training runs for one
    epoch or ``FLAGS.break_steps`` iterations, with a small replay buffer of
    past samples. The loss is:

    * ``mean(E(real)) - mean(E(sample))``;
    * plus the mean energy of the samples evaluated with ``stop_grad=True``;
    * plus the squared-energy regularisers.

    The checkpoint is then restored. The mean squared error between generated
    and real held-out images is printed, and a grid of generated (left) vs.
    real (right) images is written to ``save_exp_dir``.
    """
    X_NOISE = kvs["X_NOISE"]
    LABEL_POS = kvs["LABEL_POS"]
    model_pos = kvs["model_pos"]
    weight_pos = kvs["weight_pos"]

    # The concept composed with position, selected by the mode flags.
    if FLAGS.joint_shape:
        model_b, weight_b, label_b = (
            kvs["model_shape"], kvs["weight_shape"], kvs["LABEL_SHAPE"])
    elif FLAGS.joint_rot:
        model_b, weight_b, label_b = (
            kvs["model_rot"], kvs["weight_rot"], kvs["LABEL_ROT"])
    else:
        model_b, weight_b, label_b = (
            kvs["model_size"], kvs["weight_size"], kvs["LABEL_SIZE"])

    def _label_feed(latent):
        """Map raw latent rows to values for the active label placeholders."""
        if FLAGS.joint_shape:
            concept = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
        elif FLAGS.joint_rot:
            concept = np.concatenate(
                [np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1
            )
        else:
            concept = latent[:, 2:3]
        return {label_b: concept, LABEL_POS: latent[:, 4:]}

    X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    # Unrolled sampler alternating a position step and a second-concept step.
    x_sample = X_NOISE
    for _ in range(FLAGS.num_steps):
        # NOTE(review): the position energy and its gradient are evaluated at
        # the pre-noise point ``x_start`` but applied after a first noise draw,
        # and a second noise draw follows. This is kept as-is because saved
        # checkpoints were trained with this sampler.
        x_start = x_sample
        x_sample = x_sample + tf.random_normal(
            tf.shape(x_sample), mean=0.0, stddev=0.005
        )
        e_noise = model_pos.forward(x_start, weight_pos, label=LABEL_POS)
        x_grad = tf.gradients(e_noise, [x_start])[0]
        x_sample = x_sample + tf.random_normal(
            tf.shape(x_sample), mean=0.0, stddev=0.005
        )
        x_sample = x_sample - FLAGS.step_lr * x_grad
        x_sample = tf.clip_by_value(x_sample, 0, 1)

        e_noise = model_b.forward(x_sample, weight_b, label=label_b)
        x_grad = tf.gradients(e_noise, [x_sample])[0]
        x_sample = x_sample - FLAGS.step_lr * x_grad
        x_sample = tf.clip_by_value(x_sample, 0, 1)

    x_final = x_sample

    loss_kl = model_b.forward(
        x_final, weight_b, reuse=True, label=label_b, stop_grad=True
    ) + model_pos.forward(
        x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
    )
    energy_pos = model_b.forward(
        X, weight_b, reuse=True, label=label_b
    ) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
    x_neg = tf.stop_gradient(x_final)
    energy_neg = model_b.forward(
        x_neg, weight_b, reuse=True, label=label_b
    ) + model_pos.forward(x_neg, weight_pos, reuse=True, label=LABEL_POS)

    loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
    loss_total = (
        loss_ml
        + tf.reduce_mean(loss_kl)
        + (
            tf.reduce_mean(tf.square(energy_pos))
            + tf.reduce_mean(tf.square(energy_neg))
        )
    )

    optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
    gvs = optimizer.compute_gradients(loss_total)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    # Only the optimizer's slots are new; the concept models were restored.
    vs = optimizer.variables()
    sess.run(tf.variables_initializer(vs))

    dataloader = DataLoader(
        DSpritesGen(data, latents),
        batch_size=FLAGS.batch_size,
        num_workers=6,
        drop_last=True,
        shuffle=True,
    )

    x_off = tf.reduce_mean(tf.square(x_final - X))
    saver = tf.train.Saver()

    if FLAGS.train:
        replay_buffer = ReplayBuffer(10000)
        x_prev = None  # samples from the previous iteration (numpy)
        itr = 0
        for data_corrupt, data_batch, label_size, label_pos in tqdm(dataloader):
            data_corrupt = data_corrupt.numpy()
            data_batch = data_batch.numpy()

            if x_prev is not None:
                # Restart ~5% of the chains from previously generated samples.
                replay_buffer.add(x_prev)
                replay_batch = replay_buffer.sample(FLAGS.batch_size)
                replay_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95
                data_corrupt[replay_mask] = replay_batch[replay_mask]

            feed_dict = {
                X_NOISE: data_corrupt,
                X: data_batch,
                label_b: label_size.numpy(),
                LABEL_POS: label_pos.numpy(),
            }
            _, off_value, e_pos, e_neg, x_prev = sess.run(
                [train_op, x_off, energy_pos, energy_neg, x_final],
                feed_dict=feed_dict,
            )
            itr += 1

            if itr % 10 == 0:
                print(
                    "x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(
                        off_value, e_pos.mean(), e_neg.mean(), itr
                    )
                )

            if itr == FLAGS.break_steps:
                break

        saver.save(sess, osp.join(save_exp_dir, "model_gentest"))

    saver.restore(sess, osp.join(save_exp_dir, "model_gentest"))

    # Held-out latent combinations.
    l = latents
    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
    elif FLAGS.joint_rot:
        mask_gen = (l[:, 1] == 1) & (l[:, 2] == 0.5)
    else:
        pos_fixed = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)
        mask_gen = (
            (l[:, 3] == 30 * np.pi / 39)
            & (l[:, 1] == 1)
            & ~((l[:, 2] == 0.5) | pos_fixed)
        )

    data_gen = data[mask_gen]
    latents_gen = latents[mask_gen]

    losses = []
    for dat, latent in zip(
        np.array_split(data_gen, 120), np.array_split(latents_gen, 120)
    ):
        # NOTE(review): std 1.0 here, whereas the other start points in this
        # file use 0.5 * randn; kept so reported numbers stay comparable.
        x = 0.5 + np.random.randn(*dat.shape)
        feed_dict = _label_feed(latent)
        feed_dict[X_NOISE] = x
        feed_dict[X] = dat

        for _ in range(2):
            x = sess.run([x_final], feed_dict=feed_dict)[0]
            feed_dict[X_NOISE] = x

        # ``x_off`` re-runs the sampler from the current X_NOISE, so this is
        # the error after a third sampling round.
        losses.append(sess.run([x_off], feed_dict=feed_dict)[0])

    print("Mean MSE loss of {} ".format(np.mean(losses)))

    # Visualise up to 10 held-out images, conditioning on exactly their labels.
    n_show = min(10, data_gen.shape[0])
    data_try = data_gen[:n_show]
    data_init = 0.5 + 0.5 * np.random.randn(n_show, 64, 64)
    feed_dict = _label_feed(latents_gen[:n_show])
    feed_dict[X_NOISE] = data_init

    x_output = sess.run([x_final], feed_dict=feed_dict)[0]

    if FLAGS.joint_shape:
        im_name = "size_shape_combine_gentest.png"
    else:
        im_name = "size_scale_combine_gentest.png"

    # 1-pixel white border around each 64x64 image.
    x_output_wrap = np.ones((n_show, 66, 66))
    data_try_wrap = np.ones((n_show, 66, 66))
    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
        -1, 66 * 2
    )
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))
|
|
|
|
|
|
def conceptcombine(sess, kvs, data, latents, save_exp_dir):
    """Sample images that satisfy several concepts at once.

    Each enabled concept model contributes one gradient step per sampling
    iteration, applied in the order scale, shape, position, rotation. A step
    takes the energy gradient, adds Gaussian noise (std 0.005), steps by
    ``FLAGS.step_lr`` and clips to [0, 1]. The enabled concepts are chosen by
    ``FLAGS.cond_scale``, ``cond_shape``, ``cond_pos`` and ``cond_rot``.

    Ten samples conditioned on the latents of ``data[:10]`` are written,
    next to those images, to ``save_exp_dir`` as a PNG whose name lists the
    enabled concepts.
    """
    X_NOISE = kvs["X_NOISE"]

    # (enabled, model, weights, label placeholder, file-name suffix), in the
    # order the steps are applied within every sampling iteration.
    concepts = [
        (FLAGS.cond_scale, kvs["model_size"], kvs["weight_size"], kvs["LABEL_SIZE"], "_condscale"),
        (FLAGS.cond_shape, kvs["model_shape"], kvs["weight_shape"], kvs["LABEL_SHAPE"], "_condshape"),
        (FLAGS.cond_pos, kvs["model_pos"], kvs["weight_pos"], kvs["LABEL_POS"], "_condpos"),
        (FLAGS.cond_rot, kvs["model_rot"], kvs["weight_rot"], kvs["LABEL_ROT"], "_condrot"),
    ]
    active = [concept for concept in concepts if concept[0]]

    sample = X_NOISE
    for step in range(FLAGS.num_steps):
        for _, model, weights, label, _ in active:
            energy = model.forward(sample, weights, label=label)
            grad = tf.gradients(energy, [sample])[0]
            sample = sample + tf.random_normal(tf.shape(sample), mean=0.0, stddev=0.005)
            sample = sample - FLAGS.step_lr * grad
            sample = tf.clip_by_value(sample, 0, 1)

        print("Finished constructing loop {}".format(step))

    x_final = sample

    # Condition every label placeholder on the first ten examples' latents.
    head = latents[:10]
    angle = head[:, 3:4]
    feed_dict = {
        X_NOISE: 0.5 + 0.5 * np.random.randn(10, 64, 64),
        kvs["LABEL_SIZE"]: head[:, 2:3],
        kvs["LABEL_SHAPE"]: np.eye(3)[(head[:, 1] - 1).astype(np.uint8)],
        kvs["LABEL_POS"]: head[:, 4:],
        kvs["LABEL_ROT"]: np.concatenate([np.cos(angle), np.sin(angle)], axis=1),
    }
    generated = sess.run([x_final], feed_dict)[0]

    im_name = "im" + "".join(concept[4] for concept in active) + ".png"

    # Generated (left) and reference (right) images, each with a 1-pixel border.
    generated_pad = np.ones((10, 66, 66))
    generated_pad[:, 1:-1, 1:-1] = generated
    reference_pad = np.ones((10, 66, 66))
    reference_pad[:, 1:-1, 1:-1] = data[:10]

    grid = np.concatenate([generated_pad, reference_pad], axis=2).reshape(-1, 66 * 2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, grid)
    print("Successfully saved images at {}".format(impath))
|
|
|
|
|
|
def main():
    """Load dSprites, restore the concept EBMs and run ``FLAGS.task``.

    Four ``DspritesNet`` models are built, conditioned on size, shape,
    position and rotation, under scopes ``context_0`` .. ``context_3``. Each
    model whose ``FLAGS.cond_*`` flag is set is restored from
    ``<logdir>/<exp_*>/model_<resume_*>``. The shared placeholders/models are
    then passed to the selected task.

    Raises:
        ValueError: if ``FLAGS.task`` is not a supported task.
    """
    tasks = ("conceptcombine", "labeldiscover", "gentest", "genbaseline")
    if FLAGS.task not in tasks:
        # Fail fast instead of silently doing nothing after all the setup.
        raise ValueError(
            "Unknown task {!r}; expected one of {}".format(FLAGS.task, ", ".join(tasks))
        )

    # Read both arrays from one open of the archive and close it afterwards.
    with np.load(FLAGS.dsprites_path) as dsprites:
        data = dsprites["imgs"]
        latents = dsprites["latents_values"]

    # Fixed-seed shuffle so every task sees the same ordering.
    np.random.seed(1)
    idx = np.random.permutation(data.shape[0])
    data = data[idx]
    latents = latents[idx]

    config = tf.ConfigProto()
    sess = tf.Session(config=config)

    # Model 1 will be conditioned on size
    model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
    weight_size = model_size.construct_weights("context_0")

    # Model 2 will be conditioned on shape
    model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
    weight_shape = model_shape.construct_weights("context_1")

    # Model 3 will be conditioned on position
    model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
    weight_pos = model_pos.construct_weights("context_2")

    # Model 4 will be conditioned on rotation
    model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
    weight_rot = model_rot.construct_weights("context_3")

    sess.run(tf.global_variables_initializer())

    # Each checkpoint stores its variables under "context_0"; map the
    # in-graph scope "context_<i>" onto that name when restoring.
    restore_specs = (
        (0, FLAGS.cond_scale, FLAGS.exp_size, FLAGS.resume_size),
        (1, FLAGS.cond_shape, FLAGS.exp_shape, FLAGS.resume_shape),
        (2, FLAGS.cond_pos, FLAGS.exp_pos, FLAGS.resume_pos),
        (3, FLAGS.cond_rot, FLAGS.exp_rot, FLAGS.resume_rot),
    )
    for ctx_id, enabled, exp_name, resume_itr in restore_specs:
        if not enabled:
            continue
        scope = "context_{}".format(ctx_id)
        v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
        # Strip the ":0" output suffix to obtain checkpoint variable names.
        v_map = {v.name.replace(scope, "context_0")[:-2]: v for v in v_list}
        saver = tf.train.Saver(v_map)
        saver.restore(
            sess, osp.join(FLAGS.logdir, exp_name, "model_{}".format(resume_itr))
        )

    X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
    LABEL_SIZE = tf.placeholder(shape=(None, 1), dtype=tf.float32)
    LABEL_SHAPE = tf.placeholder(shape=(None, 3), dtype=tf.float32)
    LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
    LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)

    kvs = {
        "X_NOISE": X_NOISE,
        "LABEL_SIZE": LABEL_SIZE,
        "LABEL_SHAPE": LABEL_SHAPE,
        "LABEL_POS": LABEL_POS,
        "LABEL_ROT": LABEL_ROT,
        "model_size": model_size,
        "model_shape": model_shape,
        "model_pos": model_pos,
        "model_rot": model_rot,
        "weight_size": weight_size,
        "weight_shape": weight_shape,
        "weight_pos": weight_pos,
        "weight_rot": weight_rot,
    }

    save_exp_dir = osp.join(
        FLAGS.savedir, "{}_{}_joint".format(FLAGS.exp_size, FLAGS.exp_shape)
    )
    os.makedirs(save_exp_dir, exist_ok=True)

    if FLAGS.task == "conceptcombine":
        conceptcombine(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == "labeldiscover":
        labeldiscover(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == "gentest":
        save_exp_dir = osp.join(
            FLAGS.savedir, "{}_{}_gen".format(FLAGS.exp_size, FLAGS.exp_pos)
        )
        os.makedirs(save_exp_dir, exist_ok=True)
        gentest(sess, kvs, data, latents, save_exp_dir)
    else:  # "genbaseline" (validated above)
        save_exp_dir = osp.join(
            FLAGS.savedir, "{}_{}_gen_baseline".format(FLAGS.exp_size, FLAGS.exp_pos)
        )
        os.makedirs(save_exp_dir, exist_ok=True)

        if FLAGS.plot_curve:
            # Sweep the fraction of held-out data added to training: 0.0 .. 1.0.
            mse_losses = []
            for frac in [i / 10 for i in range(11)]:
                mse_loss = genbaseline(
                    sess, kvs, data, latents, save_exp_dir, frac=frac
                )
                mse_losses.append(mse_loss)
            np.save("mse_baseline_comb.npy", mse_losses)
        else:
            genbaseline(sess, kvs, data, latents, save_exp_dir)


if __name__ == "__main__":
    main()
|