# mirror of https://github.com/autistic-symposium/ml-ai-agents-py.git
# (synced 2025-08-16 02:00:32 -04:00)
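"""Evaluate EBM partition-function bounds with annealed importance sampling.

A sketch of what this script does (inferred from the code below): it anneals
from a uniform base distribution over the (0, rescale) box to the trained
energy-based model through FLAGS.pdist geometric bridging distributions, with
HMC transitions between them. The forward pass accumulates a stochastic lower
bound on the log partition function and the reverse pass the matching upper
bound; combined with the average data energy, this brackets the model's
log likelihood.
"""
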
import math
import os.path as osp

import numpy as np
import tensorflow as tf

# Box2D and ImagenetClass are referenced in main() below; we assume they live
# in the repo's data module alongside the other dataset classes.
from data import Box2D, Cifar10, DSprites, ImagenetClass, Mnist
from hmc import hmc
from models import DspritesNet, MnistNet, ResNet32, ResNet32Large, ResNet32Wider
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import optimistic_restore

flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
|
|
flags.DEFINE_string(
|
|
"dataset", "cifar10", "cifar10 or mnist or dsprites or 2d or toy Gauss"
|
|
)
|
|
flags.DEFINE_string(
|
|
"logdir", "cachedir", "location where log of experiments will be stored"
|
|
)
|
|
flags.DEFINE_string("exp", "default", "name of experiments")
|
|
flags.DEFINE_integer(
|
|
"data_workers", 5, "Number of different data workers to load data in parallel"
|
|
)
|
|
flags.DEFINE_integer("batch_size", 16, "Size of inputs")
|
|
flags.DEFINE_string("resume_iter", "-1", "iteration to resume training from")
|
|
|
|
flags.DEFINE_bool(
|
|
"max_pool",
|
|
False,
|
|
"Whether or not to use max pooling rather than strided convolutions",
|
|
)
|
|
flags.DEFINE_integer(
|
|
"num_filters",
|
|
64,
|
|
"number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
|
|
)
|
|
flags.DEFINE_integer("pdist", 10, "number of intermediate distributions for ais")
|
|
flags.DEFINE_integer("gauss_dim", 500, "dimensions for modeling Gaussian")
|
|
flags.DEFINE_integer(
|
|
"rescale", 1, "factor to rescale input outside of normal (0, 1) box"
|
|
)
|
|
flags.DEFINE_float(
|
|
"temperature", 1, "temperature at which to compute likelihood of model"
|
|
)
|
|
flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
|
|
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
|
|
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
|
|
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
|
|
flags.DEFINE_bool(
|
|
"cclass",
|
|
False,
|
|
"Whether to evaluate the log likelihood of conditional model or not",
|
|
)
|
|
flags.DEFINE_bool(
|
|
"single",
|
|
False,
|
|
"Whether to evaluate the log likelihood of conditional model or not",
|
|
)
|
|
flags.DEFINE_bool("large_model", False, "Use large model to evaluate")
|
|
flags.DEFINE_bool("wider_model", False, "Use large model to evaluate")
|
|
flags.DEFINE_float("alr", 0.0045, "Learning rate to use for HMC steps")
|
|
|
|
FLAGS = flags.FLAGS
|
|
|
|
# Default one-hot label (class 0), tiled when scoring a conditional model.
label_default = np.eye(10)[0:1, :]
label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))


def unscale_im(im):
    # Map an image from the (0, 1) box back to uint8 pixel values.
    return (255 * np.clip(im, 0, 1)).astype(np.uint8)


def gauss_prob_log(x, prec=1.0):
    # Log density of an isotropic Gaussian centered at 0.5 with precision
    # `prec` (variance 1 / prec) in each dimension.
    nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
    norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
    prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2.0 * prec

    return norm_constant_log + prob_density_log
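

# A minimal NumPy cross-check for gauss_prob_log (not used by the script):
# the same mean-0.5, precision parameterization evaluated eagerly, useful for
# sanity-testing the TF graph. The helper name is ours, not the repo's.
def gauss_prob_log_np(x, prec=1.0):
    x = np.asarray(x, dtype=np.float64)
    nh = float(np.prod(x.shape[1:]))
    norm_constant_log = -0.5 * nh * (np.log(2 * math.pi) - np.log(prec))
    prob_density_log = -0.5 * prec * np.sum(
        np.square(x.reshape(len(x), -1) - 0.5), axis=1
    )
    return norm_constant_log + prob_density_log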


def uniform_prob_log(x):
    # Log density of the uniform base distribution over the unit box
    # (identically zero inside the box; the box constraint itself is enforced
    # by the penalty in bridge_prob_neg_log).
    return tf.zeros(1)


def model_prob_log(x, e_func, weights, temp):
    # Unnormalized log density of the EBM target: log p(x) = -temp * E(x) + const.
    if FLAGS.cclass:
        batch_size = tf.shape(x)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_raw = e_func.forward(x, weights, label=label_tiled)
    else:
        e_raw = e_func.forward(x, weights)
    energy = tf.reduce_sum(e_raw, axis=[1])
    return -temp * energy


def bridge_prob_neg_log(alpha, x, e_func, weights, temp):
    # Negative log density of the AIS bridging distribution (see the note
    # after this function).
    if FLAGS.dataset == "gauss":
        norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(
            x, prec=FLAGS.temperature
        )
    else:
        norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * model_prob_log(
            x, e_func, weights, temp
        )

    # Add an additional log likelihood penalty so that points outside of the
    # (0, 1) box are *highly* unlikely.
    if FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
        oob_prob = tf.reduce_sum(
            tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1]
        )
    elif FLAGS.dataset == "mnist":
        oob_prob = tf.reduce_sum(
            tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1, 2]
        )
    else:
        oob_prob = tf.reduce_sum(
            tf.square(100 * (x - tf.clip_by_value(x, 0.0, FLAGS.rescale))),
            axis=[1, 2, 3],
        )

    return -norm_prob + oob_prob
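

# The bridge realizes the standard AIS geometric path: for base p_0 (uniform)
# and target p_1, log p_alpha(x) = (1 - alpha) * log p_0(x) + alpha * log p_1(x),
# i.e. p_alpha is proportional to p_0(x)**(1 - alpha) * p_1(x)**alpha, with the
# quadratic penalty keeping samples inside the (0, rescale) box at every alpha.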


def ancestral_sample(
    e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10
):
    # Build the graph for a single AIS transition: the incremental importance
    # weight between consecutive bridging distributions, followed by an HMC
    # move under the new distribution. main() drives the annealing schedule by
    # feeding alpha_prev / alpha_new.
    if FLAGS.dataset == "2d":
        x = tf.placeholder(tf.float32, shape=(None, 2))
    elif FLAGS.dataset == "gauss":
        x = tf.placeholder(tf.float32, shape=(None, FLAGS.gauss_dim))
    elif FLAGS.dataset == "mnist":
        x = tf.placeholder(tf.float32, shape=(None, 28, 28))
    else:
        x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3))

    x_init = x

    alpha_prev = tf.placeholder(tf.float32, shape=())
    alpha_new = tf.placeholder(tf.float32, shape=())
    approx_lr = tf.placeholder(tf.float32, shape=())

    prob_log_old_neg = bridge_prob_neg_log(alpha_prev, x, e_func, weights, temp)
    prob_log_new_neg = bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)

    # Incremental log importance weight: log p_new(x) - log p_old(x).
    chain_weights = -prob_log_new_neg + prob_log_old_neg

    # Sample a new x using HMC under the new bridging distribution.
    def unorm_prob(x):
        return bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)

    x = hmc(x, approx_lr, hmc_step, unorm_prob)

    return chain_weights, alpha_prev, alpha_new, x, x_init, approx_lr
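

# Accumulating the per-step weights over a full forward pass gives
#     log w = sum_j [log p_{alpha_j}(x_j) - log p_{alpha_{j-1}}(x_j)],
# and E[exp(log w)] equals the ratio of partition functions between target and
# base, so by Jensen's inequality the mean of log w is a stochastic lower
# bound on the log partition function. Running the same schedule in reverse
# from the final samples yields the matching upper bound. This is exactly the
# bookkeeping main() performs below.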


def main():
    # Initialize dataset
    if FLAGS.dataset == "cifar10":
        dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3  # input dim 32 * 32 * 3
    elif FLAGS.dataset == "imagenet":
        dataset = ImagenetClass()
        channel_num = 3  # input dim 64 * 64 * 3
    elif FLAGS.dataset == "mnist":
        dataset = Mnist(train=False, rescale=FLAGS.rescale)
        channel_num = 1  # input dim 28 * 28 * 1
    elif FLAGS.dataset == "dsprites":
        dataset = DSprites()
        channel_num = 1  # input dim 64 * 64 * 1
    elif FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
        dataset = Box2D()

    data_loader = DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.data_workers,
        drop_last=False,
        shuffle=True,
    )

    if FLAGS.dataset == "mnist":
        model = MnistNet(num_channels=channel_num)
    elif FLAGS.dataset == "cifar10":
        if FLAGS.large_model:
            model = ResNet32Large(num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)
    elif FLAGS.dataset == "dsprites":
        model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)

    weights = model.construct_weights("context_{}".format(0))

    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    saver = tf.train.Saver(max_to_keep=10)

    sess.run(tf.global_variables_initializer())
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))

    if FLAGS.resume_iter != "-1":
        optimistic_restore(sess, model_file)
    else:
        print("WARNING, YOU ARE NOT LOADING A SAVE FILE")

    chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
        model, weights, FLAGS.batch_size, temp=FLAGS.temperature
    )
    print("Finished constructing ancestral sample ...................")

    # Average energy of held-out data under the model (the data term of the
    # log likelihood).
    if FLAGS.dataset != "gauss":
        batch_size = tf.shape(x_init)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_compute = -FLAGS.temperature * model.forward(
            x_init, weights, label=label_tiled
        )
        e_pos_list = []

        for data_corrupt, data, label_gt in tqdm(data_loader):
            e_pos = sess.run([e_compute], {x_init: data})[0]
            e_pos_list.extend(list(e_pos))

        print(len(e_pos_list))
        print(
            "Positive sample unnormalized log probability ",
            np.mean(e_pos_list),
            np.std(e_pos_list),
        )

    # Per-dataset HMC step sizes (empirically tuned); FLAGS.alr is the default.
    if FLAGS.dataset == "2d":
        alr = 0.0045
    elif FLAGS.dataset == "gauss":
        alr = 0.0085
    elif FLAGS.dataset == "mnist":
        alr = 0.0065
    else:
        if FLAGS.rescale == 8:
            alr = 0.0085
        else:
            alr = 0.0045
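
    # The HMC step size fed below is annealed as alr * 5 ** (-2.5 * alpha_prev),
    # so later (peakier) bridging distributions are explored with smaller steps.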
    tot_weight = 0
    for j in tqdm(range(1, FLAGS.pdist + 1)):
        if j == 1:
            # Initialize the chain from the uniform base distribution.
            if FLAGS.dataset == "cifar10":
                x_curr = np.random.uniform(
                    0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3)
                )
            elif FLAGS.dataset == "gauss":
                x_curr = np.random.uniform(
                    0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim)
                )
            elif FLAGS.dataset == "mnist":
                x_curr = np.random.uniform(
                    0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28)
                )
            else:
                x_curr = np.random.uniform(
                    0, FLAGS.rescale, size=(FLAGS.batch_size, 2)
                )

        alpha_prev = (j - 1) / FLAGS.pdist
        alpha_new = j / FLAGS.pdist
        cweight, x_curr = sess.run(
            [chain_weights, x],
            {
                a_prev: alpha_prev,
                a_new: alpha_new,
                x_init: x_curr,
                approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
            },
        )
        tot_weight = tot_weight + cweight

    print(
        "Lower bound on log partition function from forward sampling",
        np.mean(tot_weight),
        np.std(tot_weight),
    )

    tot_weight = 0

    # Anneal back from the final samples to obtain the matching upper bound.
    for j in tqdm(range(FLAGS.pdist, 0, -1)):
        alpha_new = (j - 1) / FLAGS.pdist
        alpha_prev = j / FLAGS.pdist
        cweight, x_curr = sess.run(
            [chain_weights, x],
            {
                a_prev: alpha_prev,
                a_new: alpha_new,
                x_init: x_curr,
                approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
            },
        )
        tot_weight = tot_weight - cweight

    print(
        "Upper bound on log partition function from backward sampling",
        np.mean(tot_weight),
        np.std(tot_weight),
    )


if __name__ == "__main__":
    main()
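
# Example invocation (a sketch; the script name and flag values here are
# illustrative, not prescriptive):
#   python ais.py --dataset=cifar10 --exp=default --resume_iter=50000 \
#       --wider_model=True --batch_size=16 --pdist=10 --temperature=1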