Mirror of https://github.com/autistic-symposium/ml-ai-agents-py.git (synced 2025-11-23 16:10:49 -05:00)

chores: refactor for the new ai research, add linter, gh action, etc (#27)

parent fb4ab80dc3
commit d5467e559f
40 changed files with 5177 additions and 2476 deletions

@@ -1,68 +1,100 @@
import os
import os.path as osp

import numpy as np
import tensorflow as tf
import math
from tqdm import tqdm
from hmc import hmc
from custom_adam import AdamOptimizer
from models import DspritesNet
from scipy.misc import imsave
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader, Dataset
from models import DspritesNet
from utils import optimistic_restore, ReplayBuffer
import os.path as osp
import numpy as np
from rl_algs.logger import TensorBoardOutputFormat
from scipy.misc import imsave
import os
from custom_adam import AdamOptimizer
from tqdm import tqdm
from utils import ReplayBuffer

flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
flags.DEFINE_integer('data_workers', 4, 'Number of workers to do things')
flags.DEFINE_string('logdir', 'cachedir', 'directory for logging')
flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
flags.DEFINE_float('step_lr', 500, 'size of gradient descent size')
flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters')
flags.DEFINE_bool('cclass', True, 'not cclass')
flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results')
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.')
flags.DEFINE_bool('joint_shape', False, 'whether to use pos_size or pos_shape')
flags.DEFINE_bool('joint_rot', False, 'whether to use pos_size or pos_shape')
flags.DEFINE_integer("batch_size", 256, "Size of inputs")
flags.DEFINE_integer("data_workers", 4, "Number of workers to do things")
flags.DEFINE_string("logdir", "cachedir", "directory for logging")
flags.DEFINE_string(
    "savedir", "cachedir", "location where log of experiments will be stored"
)
flags.DEFINE_integer(
    "num_filters",
    64,
    "number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
)
flags.DEFINE_float("step_lr", 500, "size of gradient descent size")
flags.DEFINE_string(
    "dsprites_path",
    "/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
    "path to dsprites characters",
)
flags.DEFINE_bool("cclass", True, "not cclass")
flags.DEFINE_bool("proj_cclass", False, "use for backwards compatibility reasons")
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
flags.DEFINE_bool("plot_curve", False, "Generate a curve of results")
flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
flags.DEFINE_string(
    "task",
    "conceptcombine",
    "conceptcombine, labeldiscover, gentest, genbaseline, etc.",
)
flags.DEFINE_bool("joint_shape", False, "whether to use pos_size or pos_shape")
flags.DEFINE_bool("joint_rot", False, "whether to use pos_size or pos_shape")

# Conditions on which models to use
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale')
flags.DEFINE_bool("cond_pos", True, "whether to condition on position")
flags.DEFINE_bool("cond_rot", True, "whether to condition on rotation")
flags.DEFINE_bool("cond_shape", True, "whether to condition on shape")
flags.DEFINE_bool("cond_scale", True, "whether to condition on scale")

flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of experiments')
flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of experiments')
flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of experiments')
flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of experiments')
flags.DEFINE_integer('resume_size', 169000, 'First iteration to resume')
flags.DEFINE_integer('resume_shape', 477000, 'Second iteration to resume')
flags.DEFINE_integer('resume_pos', 8000, 'Second iteration to resume')
flags.DEFINE_integer('resume_rot', 690000, 'Second iteration to resume')
flags.DEFINE_integer('break_steps', 300, 'steps to break')
flags.DEFINE_string("exp_size", "dsprites_2018_cond_size", "name of experiments")
flags.DEFINE_string("exp_shape", "dsprites_2018_cond_shape", "name of experiments")
flags.DEFINE_string("exp_pos", "dsprites_2018_cond_pos_cert", "name of experiments")
flags.DEFINE_string("exp_rot", "dsprites_cond_rot_119_00", "name of experiments")
flags.DEFINE_integer("resume_size", 169000, "First iteration to resume")
flags.DEFINE_integer("resume_shape", 477000, "Second iteration to resume")
flags.DEFINE_integer("resume_pos", 8000, "Second iteration to resume")
flags.DEFINE_integer("resume_rot", 690000, "Second iteration to resume")
flags.DEFINE_integer("break_steps", 300, "steps to break")

# Whether to train for gentest
flags.DEFINE_bool('train', False, 'whether to train on generalization into multiple different predictions')
flags.DEFINE_bool(
    "train",
    False,
    "whether to train on generalization into multiple different predictions",
)

FLAGS = flags.FLAGS


class DSpritesGen(Dataset):
    def __init__(self, data, latents, frac=0.0):

        l = latents
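        # dsprites latents_values columns are (color, shape, scale,
        # orientation, posX, posY); shape 1 is the square, scale starts at
        # 0.5, and posX/posY take 32 grid values i/31. The masks below pin
        # every factor except the ones a given model is conditioned on.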

        if FLAGS.joint_shape:
            mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
            mask_size = (
                (l[:, 3] == 30 * np.pi / 39)
                & (l[:, 4] == 16 / 31)
                & (l[:, 5] == 16 / 31)
                & (l[:, 2] == 0.5)
            )
        elif FLAGS.joint_rot:
            mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
            mask_size = (
                (l[:, 1] == 1)
                & (l[:, 4] == 16 / 31)
                & (l[:, 5] == 16 / 31)
                & (l[:, 2] == 0.5)
            )
        else:
            mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1)
            mask_size = (
                (l[:, 3] == 30 * np.pi / 39)
                & (l[:, 4] == 16 / 31)
                & (l[:, 5] == 16 / 31)
                & (l[:, 1] == 1)
            )

        mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)

@@ -80,12 +112,14 @@ class DSpritesGen(Dataset):
        self.data = np.concatenate((data_pos, data_size), axis=0)
        self.label = np.concatenate((l_pos, l_size), axis=0)

        mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39))
        mask_neg = (~(mask_size & mask_pos)) & (
            (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39)
        )
        data_add = data[mask_neg]
        l_add = l[mask_neg]

        perm_idx = np.random.permutation(data_add.shape[0])
        select_idx = perm_idx[:int(frac*perm_idx.shape[0])]
        select_idx = perm_idx[: int(frac * perm_idx.shape[0])]
        data_add = data_add[select_idx]
        l_add = l_add[select_idx]

@@ -104,7 +138,9 @@
        if FLAGS.joint_shape:
            label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
        elif FLAGS.joint_rot:
            label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])])
            label_size = np.array(
                [np.cos(self.label[index, 3]), np.sin(self.label[index, 3])]
            )
        else:
            label_size = self.label[index, 2:3]

@@ -114,14 +150,16 @@


def labeldiscover(sess, kvs, data, latents, save_exp_dir):
    LABEL_SIZE = kvs['LABEL_SIZE']
    model_size = kvs['model_size']
    weight_size = kvs['weight_size']
    x_mod = kvs['X_NOISE']
    LABEL_SIZE = kvs["LABEL_SIZE"]
    model_size = kvs["model_size"]
    weight_size = kvs["weight_size"]
    x_mod = kvs["X_NOISE"]

    label_output = LABEL_SIZE
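    # Label inference by noisy gradient descent on the energy: each step
    # perturbs the candidate label with Gaussian noise and computes
    # dE/dlabel (below), steering the label toward values the trained
    # size model assigns low energy.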
    for i in range(FLAGS.num_steps):
        label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03)
        label_output = label_output + tf.random_normal(
            tf.shape(label_output), mean=0.0, stddev=0.03
        )
        e_noise = model_size.forward(x_mod, weight_size, label=label_output)
        label_grad = tf.gradients(e_noise, [label_output])[0]
        # label_grad = tf.Print(label_grad, [label_grad])
@@ -130,13 +168,13 @@ def labeldiscover(sess, kvs, data, latents, save_exp_dir):

    diffs = []
    for i in range(30):
        s = i*FLAGS.batch_size
        d = (i+1)*FLAGS.batch_size
        s = i * FLAGS.batch_size
        d = (i + 1) * FLAGS.batch_size
        data_i = data[s:d]
        latent_i = latents[s:d]
        latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))

        feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init}
        feed_dict = {x_mod: data_i, LABEL_SIZE: latent_init}
        size_pred = sess.run([label_output], feed_dict)[0]
        size_gt = latent_i[:, 2:3]

@@ -155,9 +193,11 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
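    # Regression baseline for the generalization test: a feedforward
    # generator maps a random code (X_feed) plus a label to an image and is
    # trained with the squared error (loss_sq) against the target X_label.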
    model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
    LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)

    weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))
    weights_baseline = model_baseline.construct_weights(
        "context_baseline_{}".format(frac)
    )

    X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32)
    X_feed = tf.placeholder(shape=(None, 2 * FLAGS.num_filters), dtype=tf.float32)
    X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
@@ -168,14 +208,20 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
    dataloader = DataLoader(
        DSpritesGen(data, latents, frac=frac),
        batch_size=FLAGS.batch_size,
        num_workers=6,
        drop_last=True,
        shuffle=True,
    )

    datafull = data

    itr = 0
    saver = tf.train.Saver()

    vs = optimizer.variables()
    optimizer.variables()
    sess.run(tf.global_variables_initializer())

    if FLAGS.train:
@@ -185,7 +231,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
            data_corrupt = data_corrupt.numpy()
            label_size, label_pos = label_size.numpy(), label_pos.numpy()

            data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters)
            data_corrupt = np.random.randn(
                data_corrupt.shape[0], 2 * FLAGS.num_filters
            )
            label_comb = np.concatenate([label_size, label_pos], axis=1)

            feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}
@@ -196,23 +244,27 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):

            itr += 1

    saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))
    saver.save(sess, osp.join(save_exp_dir, "model_genbaseline"))

    saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))
    saver.restore(sess, osp.join(save_exp_dir, "model_genbaseline"))

    l = latents

    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
            ~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
        )

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]
    losses = []

    for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
        data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters)
    for dat, latent in zip(
        np.array_split(data_gen, 10), np.array_split(latents_gen, 10)
    ):
        data_init = np.random.randn(dat.shape[0], 2 * FLAGS.num_filters)

        if FLAGS.joint_shape:
            latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
@@ -220,16 +272,19 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
            latent = np.concatenate([latent_size, latent_pos], axis=1)
            feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
        else:
            feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat}
            feed_dict = {X_feed: data_init, LABEL: latent[:, [2, 4, 5]], X_label: dat}
        loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
        # print(loss)
        losses.append(loss)

    print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))

    print(
        "Overall MSE for generalization of {} for fraction of {}".format(
            np.mean(losses), frac
        )
    )

    data_try = data_gen[:10]
    data_init = np.random.randn(10, 2*FLAGS.num_filters)
    data_init = np.random.randn(10, 2 * FLAGS.num_filters)

    if FLAGS.joint_shape:
        latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1]
@@ -252,7 +307,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
        -1, 66 * 2
    )
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))
@@ -261,38 +318,40 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):


def gentest(sess, kvs, data, latents, save_exp_dir):
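    # gentest: jointly fine-tune two conditional EBMs on a restricted slice
    # of the latent grid, then score reconstructions of held-out factor
    # combinations (selected by mask_gen below).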
    X_NOISE = kvs['X_NOISE']
    LABEL_SIZE = kvs['LABEL_SIZE']
    LABEL_SHAPE = kvs['LABEL_SHAPE']
    LABEL_POS = kvs['LABEL_POS']
    LABEL_ROT = kvs['LABEL_ROT']
    model_size = kvs['model_size']
    model_shape = kvs['model_shape']
    model_pos = kvs['model_pos']
    model_rot = kvs['model_rot']
    weight_size = kvs['weight_size']
    weight_shape = kvs['weight_shape']
    weight_pos = kvs['weight_pos']
    weight_rot = kvs['weight_rot']
    X_NOISE = kvs["X_NOISE"]
    LABEL_SIZE = kvs["LABEL_SIZE"]
    LABEL_SHAPE = kvs["LABEL_SHAPE"]
    LABEL_POS = kvs["LABEL_POS"]
    LABEL_ROT = kvs["LABEL_ROT"]
    model_size = kvs["model_size"]
    model_shape = kvs["model_shape"]
    model_pos = kvs["model_pos"]
    model_rot = kvs["model_rot"]
    weight_size = kvs["weight_size"]
    weight_shape = kvs["weight_shape"]
    weight_pos = kvs["weight_pos"]
    weight_rot = kvs["weight_rot"]
    X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    datafull = data
    # Test combination of generalization where we use slices of both training
    x_final = X_NOISE
    x_mod_size = X_NOISE
    x_mod_pos = X_NOISE

    for i in range(FLAGS.num_steps):

        # use cond_pos

        energies = []
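        # One Langevin-style update on the position model: inject Gaussian
        # noise, step against the energy gradient, and clip back to the
        # valid pixel range.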
        x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
        x_mod_pos = x_mod_pos + tf.random_normal(
            tf.shape(x_mod_pos), mean=0.0, stddev=0.005
        )
        e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)

        # energies.append(e_noise)
        x_grad = tf.gradients(e_noise, [x_final])[0]
        x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
        x_mod_pos = x_mod_pos + tf.random_normal(
            tf.shape(x_mod_pos), mean=0.0, stddev=0.005
        )
        x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
        x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)

@@ -332,42 +391,70 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
    x_mod = x_mod_pos
    x_final = x_mod


    if FLAGS.joint_shape:
        loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
            model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
        loss_kl = model_shape.forward(
            x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True
        ) + model_pos.forward(
            x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
        )

        energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
            model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_pos = model_shape.forward(
            X, weight_shape, reuse=True, label=LABEL_SHAPE
        ) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)

        energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
            model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_shape.forward(
            tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE
        ) + model_pos.forward(
            tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
        )
    elif FLAGS.joint_rot:
        loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
            model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
        loss_kl = model_rot.forward(
            x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True
        ) + model_pos.forward(
            x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
        )

        energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
            model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_pos = model_rot.forward(
            X, weight_rot, reuse=True, label=LABEL_ROT
        ) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)

        energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
            model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_rot.forward(
            tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT
        ) + model_pos.forward(
            tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
        )
    else:
        loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
            model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
        loss_kl = model_size.forward(
            x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True
        ) + model_pos.forward(
            x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
        )

        energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
            model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_pos = model_size.forward(
            X, weight_size, reuse=True, label=LABEL_SIZE
        ) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)

        energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
            model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_size.forward(
            tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE
        ) + model_pos.forward(
            tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
        )

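    # Contrastive objective: push energy down on real data (energy_pos) and
    # up on model samples (energy_neg), with a quadratic term keeping both
    # energies small; the importance-weighted negative term is left as an
    # unused expression after the refactor.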
    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
    energy_neg_reduced = energy_neg - tf.reduce_min(energy_neg)
    coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
    neg_loss = coeff * (-1*energy_neg) / norm_constant
    coeff * (-1 * energy_neg) / norm_constant

    loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
    loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))
    loss_total = (
        loss_ml
        + tf.reduce_mean(loss_kl)
        + 1
        * (
            tf.reduce_mean(tf.square(energy_pos))
            + tf.reduce_mean(tf.square(energy_neg))
        )
    )

    optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
    gvs = optimizer.compute_gradients(loss_total)

@@ -377,7 +464,13 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
    vs = optimizer.variables()
    sess.run(tf.variables_initializer(vs))

    dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
    dataloader = DataLoader(
        DSpritesGen(data, latents),
        batch_size=FLAGS.batch_size,
        num_workers=6,
        drop_last=True,
        shuffle=True,
    )

    x_off = tf.reduce_mean(tf.square(x_mod - X))

@@ -385,12 +478,10 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
    saver = tf.train.Saver()
    x_mod = None


    if FLAGS.train:
        replay_buffer = ReplayBuffer(10000)
        for _ in range(1):


            for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
                data_corrupt = data_corrupt.numpy()[:, :, :]
                data = data.numpy()[:, :, :]
@@ -398,29 +489,50 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
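                # Persistent chains: past samples go into the replay buffer,
                # and roughly 5% of each batch restarts from a buffered
                # sample instead of the freshly corrupted input.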
                if x_mod is not None:
                    replay_buffer.add(x_mod)
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
                    replay_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.joint_shape:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
                    feed_dict = {
                        X_NOISE: data_corrupt,
                        X: data,
                        LABEL_SHAPE: label_size,
                        LABEL_POS: label_pos,
                    }
                elif FLAGS.joint_rot:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
                    feed_dict = {
                        X_NOISE: data_corrupt,
                        X: data,
                        LABEL_ROT: label_size,
                        LABEL_POS: label_pos,
                    }
                else:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}
                    feed_dict = {
                        X_NOISE: data_corrupt,
                        X: data,
                        LABEL_SIZE: label_size,
                        LABEL_POS: label_pos,
                    }

                _, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
                _, off_value, e_pos, e_neg, x_mod = sess.run(
                    [train_op, x_off, energy_pos, energy_neg, x_final],
                    feed_dict=feed_dict,
                )
                itr += 1

                if itr % 10 == 0:
                    print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr))
                    print(
                        "x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(
                            off_value, e_pos.mean(), e_neg.mean(), itr
                        )
                    )

                if itr == FLAGS.break_steps:
                    break

        saver.save(sess, osp.join(save_exp_dir, "model_gentest"))

        saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))

    saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))
    saver.restore(sess, osp.join(save_exp_dir, "model_gentest"))

    l = latents

@@ -429,22 +541,43 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
    elif FLAGS.joint_rot:
        mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
            ~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
        )

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]

    losses = []

    for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
    for dat, latent in zip(
        np.array_split(data_gen, 120), np.array_split(latents_gen, 120)
    ):
        x = 0.5 + np.random.randn(*dat.shape)

        if FLAGS.joint_shape:
            feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
            feed_dict = {
                LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1],
                LABEL_POS: latent[:, 4:],
                X_NOISE: x,
                X: dat,
            }
        elif FLAGS.joint_rot:
            feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
            feed_dict = {
                LABEL_ROT: np.concatenate(
                    [np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1
                ),
                LABEL_POS: latent[:, 4:],
                X_NOISE: x,
                X: dat,
            }
        else:
            feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
            feed_dict = {
                LABEL_SIZE: latent[:, 2:3],
                LABEL_POS: latent[:, 4:],
                X_NOISE: x,
                X: dat,
            }

        for i in range(2):
            x = sess.run([x_final], feed_dict=feed_dict)[0]

@@ -461,11 +594,25 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
    latent_pos = latents_gen[:10, 4:]

    if FLAGS.joint_shape:
        feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos}
        feed_dict = {
            X_NOISE: data_init,
            LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32) - 1],
            LABEL_POS: latent_pos,
        }
    elif FLAGS.joint_rot:
        feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1), LABEL_POS: latent[:10, 4:], X_NOISE: data_init}
        feed_dict = {
            LABEL_ROT: np.concatenate(
                [np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1
            ),
            LABEL_POS: latent[:10, 4:],
            X_NOISE: data_init,
        }
    else:
        feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}
        feed_dict = {
            X_NOISE: data_init,
            LABEL_SIZE: latent_scale,
            LABEL_POS: latent_pos,
        }

    x_output = sess.run([x_final], feed_dict=feed_dict)[0]

@@ -480,27 +627,28 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
        -1, 66 * 2
    )
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))




def conceptcombine(sess, kvs, data, latents, save_exp_dir):
    X_NOISE = kvs['X_NOISE']
    LABEL_SIZE = kvs['LABEL_SIZE']
    LABEL_SHAPE = kvs['LABEL_SHAPE']
    LABEL_POS = kvs['LABEL_POS']
    LABEL_ROT = kvs['LABEL_ROT']
    model_size = kvs['model_size']
    model_shape = kvs['model_shape']
    model_pos = kvs['model_pos']
    model_rot = kvs['model_rot']
    weight_size = kvs['weight_size']
    weight_shape = kvs['weight_shape']
    weight_pos = kvs['weight_pos']
    weight_rot = kvs['weight_rot']
    X_NOISE = kvs["X_NOISE"]
    LABEL_SIZE = kvs["LABEL_SIZE"]
    LABEL_SHAPE = kvs["LABEL_SHAPE"]
    LABEL_POS = kvs["LABEL_POS"]
    LABEL_ROT = kvs["LABEL_ROT"]
    model_size = kvs["model_size"]
    model_shape = kvs["model_shape"]
    model_pos = kvs["model_pos"]
    model_rot = kvs["model_rot"]
    weight_size = kvs["weight_size"]
    weight_shape = kvs["weight_shape"]
    weight_pos = kvs["weight_pos"]
    weight_rot = kvs["weight_rot"]

    x_mod = X_NOISE
    for i in range(FLAGS.num_steps):
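        # Sampling loop (elided from this hunk): each conditional model
        # contributes its gradient, so sampling the summed energies composes
        # the size, shape, position, and rotation constraints in one image.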
@@ -540,13 +688,18 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
    data_try = data[:10]
    data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
    label_scale = latents[:10, 2:3]
    label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)]
    label_shape = np.eye(3)[(latents[:10, 1] - 1).astype(np.uint8)]
    label_rot = latents[:10, 3:4]
    label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
    label_pos = latents[:10, 4:]

    feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos,
                 LABEL_ROT: label_rot}
    feed_dict = {
        X_NOISE: data_init,
        LABEL_SIZE: label_scale,
        LABEL_SHAPE: label_shape,
        LABEL_POS: label_pos,
        LABEL_ROT: label_rot,
    }
    x_out = sess.run([x_final], feed_dict)[0]

    im_name = "im"

@@ -569,14 +722,15 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
    x_out_pad[:, 1:-1, 1:-1] = x_out
    data_try_pad[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2)
    im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66 * 2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))


def main():
    data = np.load(FLAGS.dsprites_path)['imgs']
    l = latents = np.load(FLAGS.dsprites_path)['latents_values']
    data = np.load(FLAGS.dsprites_path)["imgs"]
    l = latents = np.load(FLAGS.dsprites_path)["latents_values"]

    np.random.seed(1)
    idx = np.random.permutation(data.shape[0])

@@ -589,52 +743,74 @@ def main():

    # Model 1 will be conditioned on size
    model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
    weight_size = model_size.construct_weights('context_0')
    weight_size = model_size.construct_weights("context_0")

    # Model 2 will be conditioned on shape
    model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
    weight_shape = model_shape.construct_weights('context_1')
    weight_shape = model_shape.construct_weights("context_1")

    # Model 3 will be conditioned on position
    model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
    weight_pos = model_pos.construct_weights('context_2')
    weight_pos = model_pos.construct_weights("context_2")

    # Model 4 will be conditioned on rotation
    model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
    weight_rot = model_rot.construct_weights('context_3')
    weight_rot = model_rot.construct_weights("context_3")

    sess.run(tf.global_variables_initializer())
    save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size))
    save_path_size = osp.join(
        FLAGS.logdir, FLAGS.exp_size, "model_{}".format(FLAGS.resume_size)
    )

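    # Each checkpoint was trained under variable scope context_0, so the
    # variables of scope context_k are remapped to context_0 names before
    # restoring the per-factor models.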
    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0))
    v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list}
    v_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(0)
    )
    v_map = {
        (v.name.replace("context_{}".format(0), "context_0")[:-2]): v for v in v_list
    }

    if FLAGS.cond_scale:
        saver = tf.train.Saver(v_map)
        saver.restore(sess, save_path_size)

    save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape))
    save_path_shape = osp.join(
        FLAGS.logdir, FLAGS.exp_shape, "model_{}".format(FLAGS.resume_shape)
    )

    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
    v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
    v_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(1)
    )
    v_map = {
        (v.name.replace("context_{}".format(1), "context_0")[:-2]): v for v in v_list
    }

    if FLAGS.cond_shape:
        saver = tf.train.Saver(v_map)
        saver.restore(sess, save_path_shape)


    save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos))
    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
    v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
    save_path_pos = osp.join(
        FLAGS.logdir, FLAGS.exp_pos, "model_{}".format(FLAGS.resume_pos)
    )
    v_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(2)
    )
    v_map = {
        (v.name.replace("context_{}".format(2), "context_0")[:-2]): v for v in v_list
    }
    saver = tf.train.Saver(v_map)

    if FLAGS.cond_pos:
        saver.restore(sess, save_path_pos)


    save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot))
    v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3))
    v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list}
    save_path_rot = osp.join(
        FLAGS.logdir, FLAGS.exp_rot, "model_{}".format(FLAGS.resume_rot)
    )
    v_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(3)
    )
    v_map = {
        (v.name.replace("context_{}".format(3), "context_0")[:-2]): v for v in v_list
    }
    saver = tf.train.Saver(v_map)

    if FLAGS.cond_rot:
@@ -646,53 +822,57 @@ def main():
    LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
    LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)

    x_mod = X_NOISE

    kvs = {}
    kvs['X_NOISE'] = X_NOISE
    kvs['LABEL_SIZE'] = LABEL_SIZE
    kvs['LABEL_SHAPE'] = LABEL_SHAPE
    kvs['LABEL_POS'] = LABEL_POS
    kvs['LABEL_ROT'] = LABEL_ROT
    kvs['model_size'] = model_size
    kvs['model_shape'] = model_shape
    kvs['model_pos'] = model_pos
    kvs['model_rot'] = model_rot
    kvs['weight_size'] = weight_size
    kvs['weight_shape'] = weight_shape
    kvs['weight_pos'] = weight_pos
    kvs['weight_rot'] = weight_rot
    kvs["X_NOISE"] = X_NOISE
    kvs["LABEL_SIZE"] = LABEL_SIZE
    kvs["LABEL_SHAPE"] = LABEL_SHAPE
    kvs["LABEL_POS"] = LABEL_POS
    kvs["LABEL_ROT"] = LABEL_ROT
    kvs["model_size"] = model_size
    kvs["model_shape"] = model_shape
    kvs["model_pos"] = model_pos
    kvs["model_rot"] = model_rot
    kvs["weight_size"] = weight_size
    kvs["weight_shape"] = weight_shape
    kvs["weight_pos"] = weight_pos
    kvs["weight_rot"] = weight_rot

    save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape))
    save_exp_dir = osp.join(
        FLAGS.savedir, "{}_{}_joint".format(FLAGS.exp_size, FLAGS.exp_shape)
    )
    if not osp.exists(save_exp_dir):
        os.makedirs(save_exp_dir)


    if FLAGS.task == 'conceptcombine':
    if FLAGS.task == "conceptcombine":
        conceptcombine(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'labeldiscover':
    elif FLAGS.task == "labeldiscover":
        labeldiscover(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'gentest':
        save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos))
    elif FLAGS.task == "gentest":
        save_exp_dir = osp.join(
            FLAGS.savedir, "{}_{}_gen".format(FLAGS.exp_size, FLAGS.exp_pos)
        )
        if not osp.exists(save_exp_dir):
            os.makedirs(save_exp_dir)

        gentest(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'genbaseline':
        save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos))
    elif FLAGS.task == "genbaseline":
        save_exp_dir = osp.join(
            FLAGS.savedir, "{}_{}_gen_baseline".format(FLAGS.exp_size, FLAGS.exp_pos)
        )
        if not osp.exists(save_exp_dir):
            os.makedirs(save_exp_dir)

        if FLAGS.plot_curve:
            mse_losses = []
            for frac in [i/10 for i in range(11)]:
                mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac)
            for frac in [i / 10 for i in range(11)]:
                mse_loss = genbaseline(
                    sess, kvs, data, latents, save_exp_dir, frac=frac
                )
                mse_losses.append(mse_loss)
            np.save("mse_baseline_comb.npy", mse_losses)
        else:
            genbaseline(sess, kvs, data, latents, save_exp_dir)



if __name__ == "__main__":
    main()