mirror of
https://github.com/autistic-symposium/ml-ai-agents-py.git
synced 2025-05-04 23:55:18 -04:00
add cool gpt3 examples
This commit is contained in:
parent
0e8871e15e
commit
88d4a6f793
16 changed files with 205 additions and 0 deletions
249
EBMs/ais.py
Normal file
249
EBMs/ais.py
Normal file
|
@ -0,0 +1,249 @@
|
|||
import tensorflow as tf
|
||||
import math
|
||||
from hmc import hmc
|
||||
from tensorflow.python.platform import flags
|
||||
from torch.utils.data import DataLoader
|
||||
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
|
||||
from data import Cifar10, Mnist, DSprites
|
||||
from scipy.misc import logsumexp
|
||||
from scipy.misc import imsave
|
||||
from utils import optimistic_restore
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss')
|
||||
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_string('exp', 'default', 'name of experiments')
|
||||
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
|
||||
flags.DEFINE_integer('batch_size', 16, 'Size of inputs')
|
||||
flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from')
|
||||
|
||||
flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
|
||||
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
|
||||
flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais')
|
||||
flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian')
|
||||
flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box')
|
||||
flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model')
|
||||
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
|
||||
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
|
||||
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
|
||||
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
|
||||
flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not')
|
||||
flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not')
|
||||
flags.DEFINE_bool('large_model', False, 'Use large model to evaluate')
|
||||
flags.DEFINE_bool('wider_model', False, 'Use large model to evaluate')
|
||||
flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
label_default = np.eye(10)[0:1, :]
|
||||
label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))
|
||||
|
||||
|
||||
def unscale_im(im):
|
||||
return (255 * np.clip(im, 0, 1)).astype(np.uint8)
|
||||
|
||||
def gauss_prob_log(x, prec=1.0):
|
||||
|
||||
nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
|
||||
norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
|
||||
prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec
|
||||
|
||||
return norm_constant_log + prob_density_log
|
||||
|
||||
|
||||
def uniform_prob_log(x):
|
||||
|
||||
return tf.zeros(1)
|
||||
|
||||
|
||||
def model_prob_log(x, e_func, weights, temp):
|
||||
if FLAGS.cclass:
|
||||
batch_size = tf.shape(x)[0]
|
||||
label_tiled = tf.tile(label_default, (batch_size, 1))
|
||||
e_raw = e_func.forward(x, weights, label=label_tiled)
|
||||
else:
|
||||
e_raw = e_func.forward(x, weights)
|
||||
energy = tf.reduce_sum(e_raw, axis=[1])
|
||||
return -temp * energy
|
||||
|
||||
|
||||
def bridge_prob_neg_log(alpha, x, e_func, weights, temp):
|
||||
|
||||
if FLAGS.dataset == "gauss":
|
||||
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature)
|
||||
else:
|
||||
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp)
|
||||
# Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely
|
||||
|
||||
|
||||
if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
|
||||
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1])
|
||||
elif FLAGS.dataset == 'mnist':
|
||||
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1, 2])
|
||||
else:
|
||||
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis = [1, 2, 3])
|
||||
|
||||
return -norm_prob + oob_prob
|
||||
|
||||
|
||||
def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10):
|
||||
if FLAGS.dataset == "2d":
|
||||
x = tf.placeholder(tf.float32, shape=(None, 2))
|
||||
elif FLAGS.dataset == "gauss":
|
||||
x = tf.placeholder(tf.float32, shape=(None, FLAGS.gauss_dim))
|
||||
elif FLAGS.dataset == "mnist":
|
||||
x = tf.placeholder(tf.float32, shape=(None, 28, 28))
|
||||
else:
|
||||
x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3))
|
||||
|
||||
x_init = x
|
||||
|
||||
alpha_prev = tf.placeholder(tf.float32, shape=())
|
||||
alpha_new = tf.placeholder(tf.float32, shape=())
|
||||
approx_lr = tf.placeholder(tf.float32, shape=())
|
||||
|
||||
chain_weights = tf.zeros(batch_size)
|
||||
# for i in range(1, prop_dist+1):
|
||||
# print("processing loop {}".format(i))
|
||||
# alpha_prev = (i-1) / prop_dist
|
||||
# alpha_new = i / prop_dist
|
||||
|
||||
prob_log_old_neg = bridge_prob_neg_log(alpha_prev, x, e_func, weights, temp)
|
||||
prob_log_new_neg = bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)
|
||||
|
||||
chain_weights = -prob_log_new_neg + prob_log_old_neg
|
||||
# chain_weights = tf.Print(chain_weights, [chain_weights])
|
||||
|
||||
# Sample new x using HMC
|
||||
def unorm_prob(x):
|
||||
return bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)
|
||||
|
||||
for j in range(1):
|
||||
x = hmc(x, approx_lr, hmc_step, unorm_prob)
|
||||
|
||||
return chain_weights, alpha_prev, alpha_new, x, x_init, approx_lr
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Initialize dataset
|
||||
if FLAGS.dataset == 'cifar10':
|
||||
dataset = Cifar10(train=False, rescale=FLAGS.rescale)
|
||||
channel_num = 3
|
||||
dim_input = 32 * 32 * 3
|
||||
elif FLAGS.dataset == 'imagenet':
|
||||
dataset = ImagenetClass()
|
||||
channel_num = 3
|
||||
dim_input = 64 * 64 * 3
|
||||
elif FLAGS.dataset == 'mnist':
|
||||
dataset = Mnist(train=False, rescale=FLAGS.rescale)
|
||||
channel_num = 1
|
||||
dim_input = 28 * 28 * 1
|
||||
elif FLAGS.dataset == 'dsprites':
|
||||
dataset = DSprites()
|
||||
channel_num = 1
|
||||
dim_input = 64 * 64 * 1
|
||||
elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
|
||||
dataset = Box2D()
|
||||
|
||||
dim_output = 1
|
||||
data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)
|
||||
|
||||
if FLAGS.dataset == 'mnist':
|
||||
model = MnistNet(num_channels=channel_num)
|
||||
elif FLAGS.dataset == 'cifar10':
|
||||
if FLAGS.large_model:
|
||||
model = ResNet32Large(num_filters=128)
|
||||
elif FLAGS.wider_model:
|
||||
model = ResNet32Wider(num_filters=192)
|
||||
else:
|
||||
model = ResNet32(num_channels=channel_num, num_filters=128)
|
||||
elif FLAGS.dataset == 'dsprites':
|
||||
model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)
|
||||
|
||||
weights = model.construct_weights('context_{}'.format(0))
|
||||
|
||||
config = tf.ConfigProto()
|
||||
sess = tf.Session(config=config)
|
||||
saver = loader = tf.train.Saver(max_to_keep=10)
|
||||
|
||||
sess.run(tf.global_variables_initializer())
|
||||
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
|
||||
|
||||
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
|
||||
resume_itr = FLAGS.resume_iter
|
||||
|
||||
if FLAGS.resume_iter != "-1":
|
||||
optimistic_restore(sess, model_file)
|
||||
else:
|
||||
print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
|
||||
# saver.restore(sess, model_file)
|
||||
|
||||
chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
|
||||
print("Finished constructing ancestral sample ...................")
|
||||
|
||||
if FLAGS.dataset != "gauss":
|
||||
comb_weights_cum = []
|
||||
batch_size = tf.shape(x_init)[0]
|
||||
label_tiled = tf.tile(label_default, (batch_size, 1))
|
||||
e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
|
||||
e_pos_list = []
|
||||
|
||||
for data_corrupt, data, label_gt in tqdm(data_loader):
|
||||
e_pos = sess.run([e_compute], {x_init: data})[0]
|
||||
e_pos_list.extend(list(e_pos))
|
||||
|
||||
print(len(e_pos_list))
|
||||
print("Positive sample probability ", np.mean(e_pos_list), np.std(e_pos_list))
|
||||
|
||||
if FLAGS.dataset == "2d":
|
||||
alr = 0.0045
|
||||
elif FLAGS.dataset == "gauss":
|
||||
alr = 0.0085
|
||||
elif FLAGS.dataset == "mnist":
|
||||
alr = 0.0065
|
||||
#90 alr = 0.0035
|
||||
else:
|
||||
# alr = 0.0125
|
||||
if FLAGS.rescale == 8:
|
||||
alr = 0.0085
|
||||
else:
|
||||
alr = 0.0045
|
||||
#
|
||||
for i in range(1):
|
||||
tot_weight = 0
|
||||
for j in tqdm(range(1, FLAGS.pdist+1)):
|
||||
if j == 1:
|
||||
if FLAGS.dataset == "cifar10":
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
|
||||
elif FLAGS.dataset == "gauss":
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
|
||||
elif FLAGS.dataset == "mnist":
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
|
||||
else:
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))
|
||||
|
||||
alpha_prev = (j-1) / FLAGS.pdist
|
||||
alpha_new = j / FLAGS.pdist
|
||||
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
|
||||
tot_weight = tot_weight + cweight
|
||||
|
||||
print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))
|
||||
|
||||
tot_weight = 0
|
||||
|
||||
for j in tqdm(range(FLAGS.pdist, 0, -1)):
|
||||
alpha_new = (j-1) / FLAGS.pdist
|
||||
alpha_prev = j / FLAGS.pdist
|
||||
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
|
||||
tot_weight = tot_weight - cweight
|
||||
|
||||
print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Add table
Add a link
Reference in a new issue