# mirror of https://github.com/autistic-symposium/ml-ai-agents-py.git
# synced 2025-04-25 10:09:09 -04:00
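"""Estimate log-likelihood bounds for a trained energy-based model (EBM)
with annealed importance sampling (AIS).

A chain of bridge distributions interpolates between a uniform base
distribution and the target; HMC moves samples from bridge to bridge while
the accumulated importance weights give a stochastic lower bound (forward
pass) and upper bound (backward pass) on the log partition function.
"""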
import math
import os.path as osp

import numpy as np
import tensorflow as tf
from torch.utils.data import DataLoader
from tqdm import tqdm

# logsumexp moved out of scipy.misc in SciPy 1.0; scipy.misc.imsave was
# removed in SciPy 1.2 -- imageio.imwrite is the usual drop-in replacement.
from scipy.special import logsumexp
from imageio import imwrite as imsave

from tensorflow.python.platform import flags

from hmc import hmc
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
# ImagenetClass and Box2D are referenced in main() below and are assumed to
# be provided by the local data module alongside the other dataset classes.
from data import Cifar10, Mnist, DSprites, ImagenetClass, Box2D
from utils import optimistic_restore


flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
flags.DEFINE_string('dataset', 'cifar10', 'cifar10, mnist, dsprites, 2d, or gauss')
flags.DEFINE_string('logdir', 'cachedir', 'location where logs of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiment')
flags.DEFINE_integer('data_workers', 5, 'number of workers used to load data in parallel')
flags.DEFINE_integer('batch_size', 16, 'size of inputs')
flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from')

flags.DEFINE_bool('max_pool', False, 'whether to use max pooling rather than strided convolutions')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot')
flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for AIS')
flags.DEFINE_integer('gauss_dim', 500, 'dimension of the modeled Gaussian')
flags.DEFINE_integer('rescale', 1, 'factor by which to rescale inputs outside of the normal (0, 1) box')
flags.DEFINE_float('temperature', 1, 'temperature at which to compute the likelihood of the model')
flags.DEFINE_bool('bn', False, 'whether to use batch normalization')
flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'whether to use bias in convolutions')
flags.DEFINE_bool('use_attention', False, 'whether to use self-attention in the network')
flags.DEFINE_bool('cclass', False, 'whether to evaluate the log likelihood of a conditional model')
flags.DEFINE_bool('single', False, 'whether to evaluate the log likelihood on a single example')
flags.DEFINE_bool('large_model', False, 'use the large model for evaluation')
flags.DEFINE_bool('wider_model', False, 'use the wider model for evaluation')
flags.DEFINE_float('alr', 0.0045, 'learning rate (step size) to use for HMC steps')

FLAGS = flags.FLAGS

# One-hot label for class 0, used when evaluating a conditional (cclass) model.
label_default = np.eye(10)[0:1, :]
label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))


def unscale_im(im):
    return (255 * np.clip(im, 0, 1)).astype(np.uint8)


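# Log density of an isotropic Gaussian centered at 0.5 with precision `prec`:
#     log N(x; 0.5, prec^{-1} I) = -(nh/2) * (log(2*pi) - log(prec))
#                                  - (prec/2) * ||x - 0.5||^2
# where nh is the number of dimensions per example.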
def gauss_prob_log(x, prec=1.0):
    nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
    norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
    prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec

    return norm_constant_log + prob_density_log


def uniform_prob_log(x):
    # The uniform density on the unit box is 1, so its log density is 0.
    return tf.zeros(1)


def model_prob_log(x, e_func, weights, temp):
    if FLAGS.cclass:
        batch_size = tf.shape(x)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_raw = e_func.forward(x, weights, label=label_tiled)
    else:
        e_raw = e_func.forward(x, weights)

    energy = tf.reduce_sum(e_raw, axis=[1])
    return -temp * energy


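# AIS interpolates between a tractable base distribution (uniform on the box)
# and the target (the analytic Gaussian for the `gauss` dataset, otherwise the
# EBM) through a sequence of bridge distributions:
#     log p_alpha(x) = (1 - alpha) * log p_base(x) + alpha * log p_target(x)
# This function returns the *negative* unnormalized log density of p_alpha,
# plus a quadratic penalty that makes points outside the box highly unlikely.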
def bridge_prob_neg_log(alpha, x, e_func, weights, temp):
    if FLAGS.dataset == "gauss":
        norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature)
    else:
        norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp)

    # Add an additional log likelihood penalty so that points outside of the
    # (0, rescale) box are *highly* unlikely.
    if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1])
    elif FLAGS.dataset == 'mnist':
        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1, 2])
    else:
        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis=[1, 2, 3])

    return -norm_prob + oob_prob


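# Build the graph for one AIS transition: evaluate the log importance weight
# log p_{alpha_new}(x) - log p_{alpha_prev}(x) for the current samples, then
# move the samples with an HMC step targeting the new bridge distribution.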
def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10):
    if FLAGS.dataset == "2d":
        x = tf.placeholder(tf.float32, shape=(None, 2))
    elif FLAGS.dataset == "gauss":
        x = tf.placeholder(tf.float32, shape=(None, FLAGS.gauss_dim))
    elif FLAGS.dataset == "mnist":
        x = tf.placeholder(tf.float32, shape=(None, 28, 28))
    else:
        x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3))

    x_init = x

    # The annealing schedule is driven from main() through these placeholders.
    alpha_prev = tf.placeholder(tf.float32, shape=())
    alpha_new = tf.placeholder(tf.float32, shape=())
    approx_lr = tf.placeholder(tf.float32, shape=())

    prob_log_old_neg = bridge_prob_neg_log(alpha_prev, x, e_func, weights, temp)
    prob_log_new_neg = bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)

    # Per-sample log importance weight for this transition.
    chain_weights = -prob_log_new_neg + prob_log_old_neg

    # Sample new x using HMC, targeting the new bridge distribution.
    def unorm_prob(x):
        return bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)

    x = hmc(x, approx_lr, hmc_step, unorm_prob)

    return chain_weights, alpha_prev, alpha_new, x, x_init, approx_lr


def main():
    # Initialize the dataset.
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3
        dim_input = 32 * 32 * 3
    elif FLAGS.dataset == 'imagenet':
        dataset = ImagenetClass()
        channel_num = 3
        dim_input = 64 * 64 * 3
    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(train=False, rescale=FLAGS.rescale)
        channel_num = 1
        dim_input = 28 * 28 * 1
    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites()
        channel_num = 1
        dim_input = 64 * 64 * 1
    elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
        dataset = Box2D()

    dim_output = 1
    data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)

    if FLAGS.dataset == 'mnist':
        model = MnistNet(num_channels=channel_num)
    elif FLAGS.dataset == 'cifar10':
        if FLAGS.large_model:
            model = ResNet32Large(num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)
    elif FLAGS.dataset == 'dsprites':
        model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)

    weights = model.construct_weights('context_{}'.format(0))

    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    saver = tf.train.Saver(max_to_keep=10)

    sess.run(tf.global_variables_initializer())
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
    resume_itr = FLAGS.resume_iter

    if FLAGS.resume_iter != "-1":
        optimistic_restore(sess, model_file)
    else:
        print("WARNING: you are not loading a saved model file")

    chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
    print("Finished constructing ancestral sample ...................")

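    # As a reference point, compute the (unnormalized) model log density of
    # real test examples at the same temperature.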
if FLAGS.dataset != "gauss":
|
|
comb_weights_cum = []
|
|
batch_size = tf.shape(x_init)[0]
|
|
label_tiled = tf.tile(label_default, (batch_size, 1))
|
|
e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
|
|
e_pos_list = []
|
|
|
|
for data_corrupt, data, label_gt in tqdm(data_loader):
|
|
e_pos = sess.run([e_compute], {x_init: data})[0]
|
|
e_pos_list.extend(list(e_pos))
|
|
|
|
print(len(e_pos_list))
|
|
print("Positive sample probability ", np.mean(e_pos_list), np.std(e_pos_list))
|
|
|
|
if FLAGS.dataset == "2d":
|
|
alr = 0.0045
|
|
elif FLAGS.dataset == "gauss":
|
|
alr = 0.0085
|
|
elif FLAGS.dataset == "mnist":
|
|
alr = 0.0065
|
|
#90 alr = 0.0035
|
|
else:
|
|
# alr = 0.0125
|
|
if FLAGS.rescale == 8:
|
|
alr = 0.0085
|
|
else:
|
|
alr = 0.0045
|
|
#
|
|
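    # Forward AIS pass: start chains from the uniform base distribution and
    # anneal toward the target, accumulating log importance weights. The HMC
    # step size is decayed geometrically as alpha grows and the bridge
    # distribution sharpens.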
    for i in range(1):
        tot_weight = 0
        for j in tqdm(range(1, FLAGS.pdist + 1)):
            if j == 1:
                # Initialize the chains from the uniform base distribution.
                if FLAGS.dataset == "cifar10":
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
                elif FLAGS.dataset == "gauss":
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
                elif FLAGS.dataset == "mnist":
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
                else:
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))

            alpha_prev = (j - 1) / FLAGS.pdist
            alpha_new = j / FLAGS.pdist
            cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5 * -alpha_prev))})
            tot_weight = tot_weight + cweight

        print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))

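        # Backward (reverse) AIS pass: anneal from the target back to the base
        # distribution, reusing the final samples from the forward pass, so
        # the two passes bracket the log partition function.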
        tot_weight = 0

        for j in tqdm(range(FLAGS.pdist, 0, -1)):
            alpha_new = (j - 1) / FLAGS.pdist
            alpha_prev = j / FLAGS.pdist
            cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5 * -alpha_prev))})
            tot_weight = tot_weight - cweight

        print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))
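

# Illustrative sketch (not called from main): the same AIS estimator on a toy
# 1-D problem where the true log partition function is known in closed form.
# The function name and parameters here are hypothetical additions for
# documentation; this is a NumPy-only sketch of the estimator used above, with
# a Metropolis step standing in for HMC, not part of the original pipeline.
def _toy_ais_sketch(n_chains=512, n_dist=50, rng=np.random):
    """Estimate log Z of the unnormalized density exp(-x**2 / 2), whose true
    log partition function is 0.5 * log(2 * pi) ~= 0.919."""
    sigma_base = 3.0
    # The base is a wider Gaussian with a known normalizer.
    log_z_base = 0.5 * np.log(2 * np.pi * sigma_base ** 2)

    def bridge(x, a):
        # Same interpolation as bridge_prob_neg_log, with analytic endpoints.
        return (1 - a) * (-0.5 * x ** 2 / sigma_base ** 2) + a * (-0.5 * x ** 2)

    x = rng.normal(0, sigma_base, size=n_chains)  # exact sample from the base
    log_w = np.zeros(n_chains)
    for j in range(1, n_dist + 1):
        a_prev, a_new = (j - 1) / n_dist, j / n_dist
        # Accumulate the per-chain log importance weight for this transition.
        log_w += bridge(x, a_new) - bridge(x, a_prev)
        # One Metropolis step targeting the new bridge (a stand-in for HMC).
        prop = x + rng.normal(0, 0.5, size=n_chains)
        accept = np.log(rng.uniform(size=n_chains)) < bridge(prop, a_new) - bridge(x, a_new)
        x = np.where(accept, prop, x)
    # log Z_target ~= log Z_base + logsumexp(log_w) - log(n_chains).
    return log_z_base + logsumexp(log_w) - np.log(n_chains)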
if __name__ == "__main__":
|
|
main()
|