Mirror of https://github.com/autistic-symposium/ml-ai-agents-py.git (synced 2025-08-17 10:40:13 -04:00)
chores: refactor for the new ai research, add linter, gh action, etc (#27)
Commit d5467e559f (parent fb4ab80dc3)
40 changed files with 5177 additions and 2476 deletions
EBMs/ais.py (228 lines changed)

@@ -1,40 +1,65 @@
-import tensorflow as tf
 import math
+import os.path as osp
+
+import numpy as np
+import tensorflow as tf
+from data import Cifar10, DSprites, Mnist
 from hmc import hmc
+from models import DspritesNet, MnistNet, ResNet32, ResNet32Large, ResNet32Wider
 from tensorflow.python.platform import flags
 from torch.utils.data import DataLoader
-from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
-from data import Cifar10, Mnist, DSprites
-from scipy.misc import logsumexp
-from scipy.misc import imsave
-from utils import optimistic_restore
-import os.path as osp
-import numpy as np
 from tqdm import tqdm
+from utils import optimistic_restore

-flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
-flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss')
-flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
-flags.DEFINE_string('exp', 'default', 'name of experiments')
-flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
-flags.DEFINE_integer('batch_size', 16, 'Size of inputs')
-flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from')
+flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
+flags.DEFINE_string(
+    "dataset", "cifar10", "cifar10 or mnist or dsprites or 2d or toy Gauss"
+)
+flags.DEFINE_string(
+    "logdir", "cachedir", "location where log of experiments will be stored"
+)
+flags.DEFINE_string("exp", "default", "name of experiments")
+flags.DEFINE_integer(
+    "data_workers", 5, "Number of different data workers to load data in parallel"
+)
+flags.DEFINE_integer("batch_size", 16, "Size of inputs")
+flags.DEFINE_string("resume_iter", "-1", "iteration to resume training from")

-flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
-flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
-flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais')
-flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian')
-flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box')
-flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model')
-flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
-flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
-flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
-flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
-flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not')
-flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not')
-flags.DEFINE_bool('large_model', False, 'Use large model to evaluate')
-flags.DEFINE_bool('wider_model', False, 'Use large model to evaluate')
-flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps')
+flags.DEFINE_bool(
+    "max_pool",
+    False,
+    "Whether or not to use max pooling rather than strided convolutions",
+)
+flags.DEFINE_integer(
+    "num_filters",
+    64,
+    "number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
+)
+flags.DEFINE_integer("pdist", 10, "number of intermediate distributions for ais")
+flags.DEFINE_integer("gauss_dim", 500, "dimensions for modeling Gaussian")
+flags.DEFINE_integer(
+    "rescale", 1, "factor to rescale input outside of normal (0, 1) box"
+)
+flags.DEFINE_float(
+    "temperature", 1, "temperature at which to compute likelihood of model"
+)
+flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
+flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
+flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
+flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
+flags.DEFINE_bool(
+    "cclass",
+    False,
+    "Whether to evaluate the log likelihood of conditional model or not",
+)
+flags.DEFINE_bool(
+    "single",
+    False,
+    "Whether to evaluate the log likelihood of conditional model or not",
+)
+flags.DEFINE_bool("large_model", False, "Use large model to evaluate")
+flags.DEFINE_bool("wider_model", False, "Use large model to evaluate")
+flags.DEFINE_float("alr", 0.0045, "Learning rate to use for HMC steps")

 FLAGS = flags.FLAGS

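Review note: the reordered import block also drops the long-deprecated scipy.misc imports (logsumexp and imsave), which no longer exist in modern SciPy. If that functionality is needed again, the usual replacements (not part of this commit) are:

from scipy.special import logsumexp  # replaces scipy.misc.logsumexp
from imageio import imwrite          # replaces scipy.misc.imsave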
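Review note: two help strings survive the reformat with copy-paste errors: "single" reuses the description of "cclass", and "wider_model" reuses the description of "large_model". For readers unfamiliar with the TF1-era tensorflow.python.platform.flags module (a thin wrapper over absl.flags in later 1.x releases), a minimal sketch of how these definitions are consumed; the flag and override value here are illustrative, not from the commit:

import sys

from tensorflow.python.platform import flags

flags.DEFINE_integer("pdist", 10, "number of intermediate distributions for ais")
FLAGS = flags.FLAGS

# absl-backed versions can parse explicitly; older TF1 releases parsed
# lazily on the first attribute access, e.g. when main() reads FLAGS.pdist.
FLAGS(sys.argv)  # e.g. python ais.py --pdist=100
print(FLAGS.pdist)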
@@ -45,11 +70,12 @@ label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))
 def unscale_im(im):
     return (255 * np.clip(im, 0, 1)).astype(np.uint8)


 def gauss_prob_log(x, prec=1.0):

     nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
     norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
-    prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec
+    prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2.0 * prec

     return norm_constant_log + prob_density_log
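Review note: gauss_prob_log evaluates the log-density of an isotropic Gaussian with mean 0.5 and precision prec over the nh non-batch dimensions, log N(x) = (nh/2)(log prec - log 2*pi) - (prec/2) * ||x - 0.5||^2, which matches the two terms in the code. A standalone NumPy re-derivation of the same formula, as a checking sketch rather than part of the commit:

import numpy as np

def gauss_prob_log_np(x, prec=1.0):
    # log N(x; mean=0.5 * ones, covariance=(1 / prec) * identity)
    nh = float(np.prod(x.shape[1:]))
    norm_constant_log = -0.5 * (np.log(2 * np.pi) * nh - nh * np.log(prec))
    prob_density_log = -np.sum(np.square(x - 0.5), axis=1) / 2.0 * prec
    return norm_constant_log + prob_density_log

x = np.random.rand(4, 500)  # batch of 4 at the default gauss_dim of 500
print(gauss_prob_log_np(x))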
@@ -73,23 +99,36 @@ def model_prob_log(x, e_func, weights, temp):
 def bridge_prob_neg_log(alpha, x, e_func, weights, temp):

     if FLAGS.dataset == "gauss":
-        norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature)
+        norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(
+            x, prec=FLAGS.temperature
+        )
     else:
-        norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp)
-    # Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely
+        norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * model_prob_log(
+            x, e_func, weights, temp
+        )
+    # Add an additional log likelihood penalty so that points outside of (0,
+    # 1) box are *highly* unlikely

-    if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
-        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1])
-    elif FLAGS.dataset == 'mnist':
-        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1, 2])
+    if FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
+        oob_prob = tf.reduce_sum(
+            tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1]
+        )
+    elif FLAGS.dataset == "mnist":
+        oob_prob = tf.reduce_sum(
+            tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1, 2]
+        )
     else:
-        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis = [1, 2, 3])
+        oob_prob = tf.reduce_sum(
+            tf.square(100 * (x - tf.clip_by_value(x, 0.0, FLAGS.rescale))),
+            axis=[1, 2, 3],
+        )

     return -norm_prob + oob_prob


-def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10):
+def ancestral_sample(
+    e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10
+):
     if FLAGS.dataset == "2d":
         x = tf.placeholder(tf.float32, shape=(None, 2))
     elif FLAGS.dataset == "gauss":
@@ -130,41 +169,46 @@ def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_
 def main():

     # Initialize dataset
-    if FLAGS.dataset == 'cifar10':
+    if FLAGS.dataset == "cifar10":
         dataset = Cifar10(train=False, rescale=FLAGS.rescale)
         channel_num = 3
-        dim_input = 32 * 32 * 3
-    elif FLAGS.dataset == 'imagenet':
+        32 * 32 * 3
+    elif FLAGS.dataset == "imagenet":
         dataset = ImagenetClass()
         channel_num = 3
-        dim_input = 64 * 64 * 3
-    elif FLAGS.dataset == 'mnist':
+        64 * 64 * 3
+    elif FLAGS.dataset == "mnist":
         dataset = Mnist(train=False, rescale=FLAGS.rescale)
         channel_num = 1
-        dim_input = 28 * 28 * 1
-    elif FLAGS.dataset == 'dsprites':
+        28 * 28 * 1
+    elif FLAGS.dataset == "dsprites":
         dataset = DSprites()
         channel_num = 1
-        dim_input = 64 * 64 * 1
-    elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
+        64 * 64 * 1
+    elif FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
         dataset = Box2D()

     dim_output = 1
-    data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)
+    data_loader = DataLoader(
+        dataset,
+        batch_size=FLAGS.batch_size,
+        num_workers=FLAGS.data_workers,
+        drop_last=False,
+        shuffle=True,
+    )

-    if FLAGS.dataset == 'mnist':
+    if FLAGS.dataset == "mnist":
         model = MnistNet(num_channels=channel_num)
-    elif FLAGS.dataset == 'cifar10':
+    elif FLAGS.dataset == "cifar10":
         if FLAGS.large_model:
             model = ResNet32Large(num_filters=128)
         elif FLAGS.wider_model:
             model = ResNet32Wider(num_filters=192)
         else:
             model = ResNet32(num_channels=channel_num, num_filters=128)
-    elif FLAGS.dataset == 'dsprites':
+    elif FLAGS.dataset == "dsprites":
         model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)

-    weights = model.construct_weights('context_{}'.format(0))
+    weights = model.construct_weights("context_{}".format(0))

     config = tf.ConfigProto()
     sess = tf.Session(config=config)
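Review note: the linter rewrite turned the unused dim_input assignments into bare expressions (32 * 32 * 3 and friends), and below it leaves a bare FLAGS.resume_iter where resume_itr used to be assigned. These statements have no effect; deleting the lines outright would be cleaner than keeping no-op expressions.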
@@ -173,8 +217,8 @@ def main():
     sess.run(tf.global_variables_initializer())
     logdir = osp.join(FLAGS.logdir, FLAGS.exp)

-    model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
-    resume_itr = FLAGS.resume_iter
+    model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
+    FLAGS.resume_iter

     if FLAGS.resume_iter != "-1":
         optimistic_restore(sess, model_file)
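Review note: optimistic_restore, imported from the repo's utils, appears to be the usual "restore whatever matches" checkpoint helper, which is why a missing checkpoint below only triggers a warning rather than an error. A generic TF1-style sketch of the idea (an assumed implementation for illustration; the repo's actual version may differ):

import tensorflow as tf

def optimistic_restore_sketch(session, save_file):
    # Restore only the variables whose name and shape match the checkpoint.
    reader = tf.train.NewCheckpointReader(save_file)
    saved_shapes = reader.get_variable_to_shape_map()
    restorable = [
        v for v in tf.global_variables()
        if v.name.split(":")[0] in saved_shapes
        and v.get_shape().as_list() == saved_shapes[v.name.split(":")[0]]
    ]
    tf.train.Saver(restorable).restore(session, save_file)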
@@ -182,14 +226,17 @@ def main():
         print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
         # saver.restore(sess, model_file)

-    chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
+    chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
+        model, weights, FLAGS.batch_size, temp=FLAGS.temperature
+    )
     print("Finished constructing ancestral sample ...................")

     if FLAGS.dataset != "gauss":
         comb_weights_cum = []
         batch_size = tf.shape(x_init)[0]
         label_tiled = tf.tile(label_default, (batch_size, 1))
-        e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
+        e_compute = -FLAGS.temperature * model.forward(
+            x_init, weights, label=label_tiled
+        )
         e_pos_list = []

         for data_corrupt, data, label_gt in tqdm(data_loader):
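Review note: the sign and scale in e_compute follow the EBM convention that the unnormalized model log-density is log p(x) = -T * E(x) - log Z for energy E(x) and temperature T, the same convention model_prob_log relies on. A toy NumPy illustration with a stand-in quadratic energy:

import numpy as np

def unnormalized_log_prob(x, energy_fn, temperature=1.0):
    # log p(x) up to the unknown log-partition term log Z.
    return -temperature * energy_fn(x)

quad_energy = lambda v: np.sum(np.square(v), axis=-1)  # stand-in energy
print(unnormalized_log_prob(np.random.rand(16, 10), quad_energy))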
@@ -205,44 +252,75 @@ def main():
         alr = 0.0085
     elif FLAGS.dataset == "mnist":
         alr = 0.0065
-        #90 alr = 0.0035
+        # 90 alr = 0.0035
     else:
         # alr = 0.0125
         if FLAGS.rescale == 8:
             alr = 0.0085
         else:
             alr = 0.0045
     #
     #
     for i in range(1):
         tot_weight = 0
-        for j in tqdm(range(1, FLAGS.pdist+1)):
+        for j in tqdm(range(1, FLAGS.pdist + 1)):
             if j == 1:
                 if FLAGS.dataset == "cifar10":
-                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
+                    x_curr = np.random.uniform(
+                        0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3)
+                    )
                 elif FLAGS.dataset == "gauss":
-                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
+                    x_curr = np.random.uniform(
+                        0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim)
+                    )
                 elif FLAGS.dataset == "mnist":
-                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
+                    x_curr = np.random.uniform(
+                        0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28)
+                    )
                 else:
-                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))
+                    x_curr = np.random.uniform(
+                        0, FLAGS.rescale, size=(FLAGS.batch_size, 2)
+                    )

-            alpha_prev = (j-1) / FLAGS.pdist
+            alpha_prev = (j - 1) / FLAGS.pdist
             alpha_new = j / FLAGS.pdist
-            cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
+            cweight, x_curr = sess.run(
+                [chain_weights, x],
+                {
+                    a_prev: alpha_prev,
+                    a_new: alpha_new,
+                    x_init: x_curr,
+                    approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
+                },
+            )
             tot_weight = tot_weight + cweight

-        print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))
+        print(
+            "Total values of lower value based off forward sampling",
+            np.mean(tot_weight),
+            np.std(tot_weight),
+        )

         tot_weight = 0

         for j in tqdm(range(FLAGS.pdist, 0, -1)):
-            alpha_new = (j-1) / FLAGS.pdist
+            alpha_new = (j - 1) / FLAGS.pdist
             alpha_prev = j / FLAGS.pdist
-            cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
+            cweight, x_curr = sess.run(
+                [chain_weights, x],
+                {
+                    a_prev: alpha_prev,
+                    a_new: alpha_new,
+                    x_init: x_curr,
+                    approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
+                },
+            )
             tot_weight = tot_weight - cweight

-        print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))
+        print(
+            "Total values of upper value based off backward sampling",
+            np.mean(tot_weight),
+            np.std(tot_weight),
+        )


 if __name__ == "__main__":
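Review note: the two annealing loops sandwich the log partition function bidirectionally: the forward pass accumulates log-weights whose log-mean-exp gives a stochastic lower bound on log Z, and rerunning the schedule in reverse from the final samples gives the matching upper bound, which is what the two prints report. A toy illustration of turning such log-weights into the two bounds (the arrays are stand-ins; a real run would use the accumulated tot_weight values):

import numpy as np

def log_mean_exp(logw):
    # Numerically stable log(mean(exp(logw))).
    m = np.max(logw)
    return m + np.log(np.mean(np.exp(logw - m)))

logw_fwd = np.random.randn(16)  # stand-in forward AIS log-weights
logw_bwd = np.random.randn(16)  # stand-in reverse AIS log-weights
log_z_lower = log_mean_exp(logw_fwd)    # E[exp(logw_fwd)] equals Z
log_z_upper = -log_mean_exp(-logw_bwd)  # reverse run bounds Z from above
print(log_z_lower, log_z_upper)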