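"""Train an energy-based model (EBM) with Langevin-dynamics MCMC sampling.

The script builds the energy network for the dataset selected by --dataset,
draws negative samples with a tf.while_loop Langevin sampler (optionally HMC,
or chains resumed from a persistent replay buffer), and trains with the
objective selected by --objective. It assumes the companion modules of this
repository (data.py, models.py, utils.py, custom_adam.py, hmc.py,
inception.py) are importable, and always initializes Horovod, even for a
single worker. An illustrative invocation (flags are all defined below):

    python train.py --dataset=cifar10 --exp=cifar10_run --batch_size=128

Each framework is seeded with hvd.rank() so that every Horovod worker draws a
different stream of data and noise.
"""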
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import flags

from data import Imagenet, Cifar10, DSprites, Mnist, TFImagenetLoader
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, MnistNet, ResNet128
import os.path as osp
import os
from baselines.logger import TensorBoardOutputFormat
from utils import average_gradients, ReplayBuffer, optimistic_restore
from tqdm import tqdm
import random
from torch.utils.data import DataLoader
import time
from io import StringIO
from tensorflow.core.util import event_pb2
import torch
from custom_adam import AdamOptimizer
from scipy.misc import imsave
import matplotlib.pyplot as plt
from hmc import hmc

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

import horovod.tensorflow as hvd
hvd.init()

from inception import get_inception_score

torch.manual_seed(hvd.rank())
np.random.seed(hvd.rank())
tf.set_random_seed(hvd.rank())

FLAGS = flags.FLAGS

# Dataset Options
flags.DEFINE_string('datasource', 'random',
                    'initialization for chains, either random or default (decorruption)')
flags.DEFINE_string('dataset', 'mnist',
                    'dsprites, cifar10, imagenet (32x32) or imagenetfull (128x128)')
flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
flags.DEFINE_bool('single', False, 'whether to debug by training on a single image')
flags.DEFINE_integer('data_workers', 4,
                     'Number of different data workers to load data in parallel')

# General Experiment Settings
flags.DEFINE_string('logdir', 'cachedir',
                    'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiment')
flags.DEFINE_integer('log_interval', 10, 'log outputs every so many batches')
flags.DEFINE_integer('save_interval', 1000, 'save outputs every so many batches')
flags.DEFINE_integer('test_interval', 1000, 'evaluate outputs every so many batches')
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
flags.DEFINE_bool('train', True, 'whether to train or test')
flags.DEFINE_integer('epoch_num', 10000, 'Number of epochs to train on')
flags.DEFINE_float('lr', 3e-4, 'Learning rate for training')
flags.DEFINE_integer('num_gpus', 1, 'number of gpus to train on')

# EBM Specific Experiment Settings
flags.DEFINE_float('ml_coeff', 1.0, 'Maximum likelihood coefficient')
flags.DEFINE_float('l2_coeff', 1.0, 'L2 penalty on energy magnitudes')
flags.DEFINE_bool('cclass', False, 'Whether to use class-conditional training')
flags.DEFINE_bool('model_cclass', False, 'use unsupervised clustering to infer fake labels')
flags.DEFINE_integer('temperature', 1, 'Temperature for energy function')
flags.DEFINE_string('objective', 'cd', 'use either the contrastive divergence objective (least stable), '
                    'logsumexp (more stable) or softplus (most stable)')
flags.DEFINE_bool('zero_kl', False, 'whether to zero out the kl loss')

# Settings for MCMC sampling
flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images')
flags.DEFINE_string('proj_norm_type', 'li', 'Either li or l2 ball projection')
flags.DEFINE_integer('num_steps', 20, 'Steps of gradient descent for training')
flags.DEFINE_float('step_lr', 1.0, 'Size of steps for gradient descent')
flags.DEFINE_bool('replay_batch', False, 'Use MCMC chains initialized from a replay buffer.')
flags.DEFINE_bool('hmc', False, 'Whether to use HMC sampling to train models')
flags.DEFINE_float('noise_scale', 1., 'Relative amount of noise for MCMC')
flags.DEFINE_bool('pcd', False, 'whether to use pcd training instead')

# Architecture Settings
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('large_model', False, 'whether to use a large model')
flags.DEFINE_bool('larger_model', False, 'Deeper ResNet32 Network')
flags.DEFINE_bool('wider_model', False, 'Wider ResNet32 Network')

# Dataset settings
flags.DEFINE_bool('mixup', False, 'whether to add mixup to training images')
flags.DEFINE_bool('augment', False, 'whether to apply augmentations to images')
flags.DEFINE_float('rescale', 1.0, 'Factor to rescale inputs from 0-1 box')

# Dsprites specific experiments
flags.DEFINE_bool('cond_shape', False, 'condition on shape type')
flags.DEFINE_bool('cond_size', False, 'condition on shape size')
flags.DEFINE_bool('cond_pos', False, 'condition on shape position')
flags.DEFINE_bool('cond_rot', False, 'condition on shape rotation')

FLAGS.step_lr = FLAGS.step_lr * FLAGS.rescale

FLAGS.batch_size *= FLAGS.num_gpus

print("{} batch size".format(FLAGS.batch_size))
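

# The replay buffer stores past MCMC samples as uint8 to save memory:
# compress_x_mod quantizes a float array in [0, FLAGS.rescale] down to uint8,
# and decompress_x_mod maps it back, adding uniform dequantization noise so
# values do not collapse onto the 256 quantization bins.
# Illustrative round trip (a sketch, not executed anywhere in this script):
#   x = np.random.uniform(0, FLAGS.rescale, (4, 32, 32, 3))
#   x8 = compress_x_mod(x)         # uint8, 4x smaller than float32
#   x_back = decompress_x_mod(x8)  # float in [0, FLAGS.rescale] again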
def compress_x_mod(x_mod):
    x_mod = (255 * np.clip(x_mod, 0, FLAGS.rescale) / FLAGS.rescale).astype(np.uint8)
    return x_mod


def decompress_x_mod(x_mod):
    x_mod = x_mod / 256 * FLAGS.rescale + \
        np.random.uniform(0, 1 / 256 * FLAGS.rescale, x_mod.shape)
    return x_mod
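

# make_image/log_image write images straight into TensorBoard event files via
# the TF1 Summary protobufs (a tf.Summary.Image wrapped in an event_pb2.Event),
# which lets plain numpy arrays be logged without building summary ops.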
def make_image(tensor):
    """Convert a numpy representation of an image to an Image protobuf."""
    from PIL import Image
    if len(tensor.shape) == 4:
        _, height, width, channel = tensor.shape
    elif len(tensor.shape) == 3:
        height, width, channel = tensor.shape
    elif len(tensor.shape) == 2:
        height, width = tensor.shape
        channel = 1
    tensor = tensor.astype(np.uint8)
    image = Image.fromarray(tensor)
    import io
    output = io.BytesIO()
    image.save(output, format='PNG')
    image_string = output.getvalue()
    output.close()
    return tf.Summary.Image(height=height,
                            width=width,
                            colorspace=channel,
                            encoded_image_string=image_string)


def log_image(im, logger, tag, step=0):
    im = make_image(im)

    summary = [tf.Summary.Value(tag=tag, image=im)]
    summary = tf.Summary(value=summary)
    event = event_pb2.Event(summary=summary)
    event.step = step
    logger.writer.WriteEvent(event)
    logger.writer.Flush()
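

# rescale_im maps a float image in [0, FLAGS.rescale] to uint8 in [0, 255].
# mnist and dsprites intensities are inverted first (FLAGS.rescale - image)
# so digits and shapes render dark-on-light in TensorBoard.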
def rescale_im(image):
    image = np.clip(image, 0, FLAGS.rescale)
    if FLAGS.dataset == 'mnist' or FLAGS.dataset == 'dsprites':
        return (np.clip((FLAGS.rescale - image) * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
    else:
        return (np.clip(image * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
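

# train() is the main loop. Per batch it (i) feeds real data for the positive
# phase, (ii) runs the Langevin sampler (built in main()) from noise or
# replay-buffer initializations for the negative phase, and (iii) applies the
# objective selected by --objective. It also logs scalar diagnostics every
# log_interval, checkpoints every save_interval, and renders
# (init | sample | data) image triplets every test_interval.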
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
    X = target_vars['X']
    Y = target_vars['Y']
    X_NOISE = target_vars['X_NOISE']
    train_op = target_vars['train_op']
    energy_pos = target_vars['energy_pos']
    energy_neg = target_vars['energy_neg']
    loss_energy = target_vars['loss_energy']
    loss_ml = target_vars['loss_ml']
    loss_total = target_vars['total_loss']
    gvs = target_vars['gvs']
    x_grad = target_vars['x_grad']
    x_grad_first = target_vars['x_grad_first']
    x_off = target_vars['x_off']
    temp = target_vars['temp']
    x_mod = target_vars['x_mod']
    LABEL = target_vars['LABEL']
    LABEL_POS = target_vars['LABEL_POS']
    weights = target_vars['weights']
    test_x_mod = target_vars['test_x_mod']
    eps = target_vars['eps_begin']
    label_ent = target_vars['label_ent']

    if FLAGS.use_attention:
        gamma = weights[0]['atten']['gamma']
    else:
        gamma = tf.zeros(1)

    val_output = [test_x_mod]

    gvs_dict = dict(gvs)

    log_output = [
        train_op,
        energy_pos,
        energy_neg,
        eps,
        loss_energy,
        loss_ml,
        loss_total,
        x_grad,
        x_off,
        x_mod,
        gamma,
        x_grad_first,
        label_ent,
        *gvs_dict.keys()]
    output = [train_op, x_mod]

    replay_buffer = ReplayBuffer(10000)
    itr = resume_iter
    x_mod = None
    gd_steps = 1

    dataloader_iterator = iter(dataloader)
    best_inception = 0.0

    for epoch in range(FLAGS.epoch_num):
        for data_corrupt, data, label in dataloader:
            data_corrupt = data_corrupt.numpy()
            data_corrupt_init = data_corrupt.copy()

            data = data.numpy()
            label = label.numpy()

            label_init = label.copy()
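
            # Mixup regularization: replace each image with a random convex
            # combination of itself and another image in the batch.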
            if FLAGS.mixup:
                idx = np.random.permutation(data.shape[0])
                lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
                data = data * lam + data[idx] * (1 - lam)
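
            # Store recent MCMC samples, then restart most chains from the
            # buffer: the mask keeps a replay sample wherever a uniform draw
            # over [0, FLAGS.rescale] exceeds 0.05 (~95% of chains when
            # rescale is 1), which gives longer effective chains.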
            if FLAGS.replay_batch and (x_mod is not None):
                replay_buffer.add(compress_x_mod(x_mod))

                if len(replay_buffer) > FLAGS.batch_size:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (
                        np.random.uniform(
                            0,
                            FLAGS.rescale,
                            FLAGS.batch_size) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

            if FLAGS.pcd:
                if x_mod is not None:
                    data_corrupt = x_mod

            feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}

            if FLAGS.cclass:
                feed_dict[LABEL] = label
                feed_dict[LABEL_POS] = label_init
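
            # On logging iterations, run the full diagnostic fetch list; the
            # unpack below must match the order of log_output exactly, with
            # the trailing per-variable gradients collected into grads.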
            if itr % FLAGS.log_interval == 0:
                _, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent, * \
                    grads = sess.run(log_output, feed_dict)

                kvs = {}
                kvs['e_pos'] = e_pos.mean()
                kvs['e_pos_std'] = e_pos.std()
                kvs['e_neg'] = e_neg.mean()
                kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                kvs['e_neg_std'] = e_neg.std()
                kvs['temp'] = temp
                kvs['loss_e'] = loss_e.mean()
                kvs['eps'] = eps.mean()
                kvs['label_ent'] = label_ent
                kvs['loss_ml'] = loss_ml.mean()
                kvs['loss_total'] = loss_total.mean()
                kvs['x_grad'] = np.abs(x_grad).mean()
                kvs['x_grad_first'] = np.abs(x_grad_first).mean()
                kvs['x_off'] = x_off.mean()
                kvs['iter'] = itr
                kvs['gamma'] = gamma

                for v, k in zip(grads, [v.name for v in gvs_dict.values()]):
                    kvs[k] = np.abs(v).max()

                string = "Obtained a total of "
                for key, value in kvs.items():
                    string += "{}: {}, ".format(key, value)

                if hvd.rank() == 0:
                    print(string)
                    logger.writekvs(kvs)
            else:
                _, x_mod = sess.run(output, feed_dict)

            if itr % FLAGS.save_interval == 0 and hvd.rank() == 0:
                saver.save(
                    sess,
                    osp.join(
                        FLAGS.logdir,
                        FLAGS.exp,
                        'model_{}'.format(itr)))

            if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != '2d':
                try_im = x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data)

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i

                    log_image(
                        new_im, logger, 'train_gen_{}'.format(itr), step=i)

                test_im = x_mod

                try:
                    data_corrupt, data, label = next(dataloader_iterator)
                except BaseException:
                    dataloader_iterator = iter(dataloader)
                    data_corrupt, data, label = next(dataloader_iterator)

                data_corrupt = data_corrupt.numpy()

                if FLAGS.replay_batch and (
                        x_mod is not None) and len(replay_buffer) > 0:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (
                        np.random.uniform(
                            0, 1, (FLAGS.batch_size)) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.dataset == 'cifar10' or FLAGS.dataset == 'imagenet' or FLAGS.dataset == 'imagenetfull':
                    n = 128

                    if FLAGS.dataset == "imagenetfull":
                        n = 32

                    if len(replay_buffer) > n:
                        data_corrupt = decompress_x_mod(replay_buffer.sample(n))
                    elif FLAGS.dataset == 'imagenetfull':
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 128, 128, 3))
                    else:
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 32, 32, 3))

                    if FLAGS.dataset == 'cifar10':
                        label = np.eye(10)[np.random.randint(0, 10, (n))]
                    else:
                        label = np.eye(1000)[
                            np.random.randint(
                                0, 1000, (n))]

                feed_dict[X_NOISE] = data_corrupt

                feed_dict[X] = data

                if FLAGS.cclass:
                    feed_dict[LABEL] = label

                test_x_mod = sess.run(val_output, feed_dict)

                try_im = test_x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data.numpy())

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):

                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i
                    log_image(
                        new_im, logger, 'val_gen_{}'.format(itr), step=i)

                score, std = get_inception_score(list(try_im), splits=1)
                print(
                    "Inception score of {} with std of {}".format(
                        score, std))
                kvs = {}
                kvs['inception_score'] = score
                kvs['inception_score_std'] = std
                logger.writekvs(kvs)

                if score > best_inception:
                    best_inception = score
                    saver.save(
                        sess,
                        osp.join(
                            FLAGS.logdir,
                            FLAGS.exp,
                            'model_best'))
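
            # Deliberate hard stop: MNIST runs are aborted after 60k
            # iterations rather than training further.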
if itr > 60000 and FLAGS.dataset == "mnist":
|
|
assert False
|
|
itr += 1
|
|
|
|
saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))
|
|
|
|
|
|


cifar10_map = {0: 'airplane',
               1: 'automobile',
               2: 'bird',
               3: 'cat',
               4: 'deer',
               5: 'dog',
               6: 'frog',
               7: 'horse',
               8: 'ship',
               9: 'truck'}
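

# test() renders (init | sample) panels annotated with their start and end
# energies, then draws roughly 50,000 samples to compute an Inception score
# against the real data.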
def test(target_vars, saver, sess, logger, dataloader):
    X_NOISE = target_vars['X_NOISE']
    X = target_vars['X']
    Y = target_vars['Y']
    LABEL = target_vars['LABEL']
    energy_start = target_vars['energy_start']
    x_mod = target_vars['test_x_mod']
    energy_neg = target_vars['energy_neg']

    np.random.seed(1)
    random.seed(1)

    output = [x_mod, energy_start, energy_neg]

    dataloader_iterator = iter(dataloader)
    data_corrupt, data, label = next(dataloader_iterator)
    data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()

    orig_im = try_im = data_corrupt

    if FLAGS.cclass:
        try_im, energy_orig, energy = sess.run(
            output, {X_NOISE: orig_im, Y: label[0:1], LABEL: label})
    else:
        try_im, energy_orig, energy = sess.run(
            output, {X_NOISE: orig_im, Y: label[0:1]})

    orig_im = rescale_im(orig_im)
    try_im = rescale_im(try_im)
    actual_im = rescale_im(data)

    for i, (im, energy_i, t_im, energy_neg_i, label_i, actual_im_i) in enumerate(
            zip(orig_im, energy_orig, try_im, energy, label, actual_im)):
        label_i = np.array(label_i)

        shape = im.shape
        new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
        size = shape[1]
        new_im[:, :size] = im
        new_im[:, size:2 * size] = t_im

        if FLAGS.cclass:
            label_i = np.where(label_i == 1)[0][0]
            if FLAGS.dataset == 'cifar10':
                log_image(new_im, logger, '{}_{:.4f}_now_{:.4f}_{}'.format(
                    i, energy_i[0], energy_neg_i[0], cifar10_map[label_i]), step=i)
            else:
                log_image(
                    new_im,
                    logger,
                    '{}_{:.4f}_now_{:.4f}_{}'.format(
                        i,
                        energy_i[0],
                        energy_neg_i[0],
                        label_i),
                    step=i)
        else:
            log_image(
                new_im,
                logger,
                '{}_{:.4f}_now_{:.4f}'.format(
                    i,
                    energy_i[0],
                    energy_neg_i[0]),
                step=i)

    test_ims = list(try_im)
    real_ims = list(actual_im)

    for i in tqdm(range(50000 // FLAGS.batch_size + 1)):
        try:
            data_corrupt, data, label = next(dataloader_iterator)
        except BaseException:
            dataloader_iterator = iter(dataloader)
            data_corrupt, data, label = next(dataloader_iterator)

        data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()

        if FLAGS.cclass:
            try_im, energy_orig, energy = sess.run(
                output, {X_NOISE: data_corrupt, Y: label[0:1], LABEL: label})
        else:
            try_im, energy_orig, energy = sess.run(
                output, {X_NOISE: data_corrupt, Y: label[0:1]})

        try_im = rescale_im(try_im)
        real_im = rescale_im(data)

        test_ims.extend(list(try_im))
        real_ims.extend(list(real_im))

    score, std = get_inception_score(test_ims)
    print("Inception score of {} with std of {}".format(score, std))


def main():
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    LABEL = None
    print("Loading data...")
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(
                num_channels=channel_num,
                num_filters=128,
                train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(
                num_channels=channel_num,
                num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(
                num_channels=channel_num,
                num_filters=192)
        else:
            model = ResNet32(
                num_channels=channel_num,
                num_filters=128)

    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(
            num_channels=channel_num,
            num_filters=256)

    elif FLAGS.dataset == 'imagenetfull':
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(
            num_channels=channel_num,
            num_filters=64)

    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1
        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(
            num_channels=channel_num,
            num_filters=FLAGS.num_filters)

    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(
            cond_shape=FLAGS.cond_shape,
            cond_size=FLAGS.cond_size,
            cond_pos=FLAGS.cond_pos,
            cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
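
        # dpos_only, dsize_only and drot_only are not defined in this file;
        # they are presumably registered by an imported module (TF flags
        # share one global registry across modules).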
        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(
            num_channels=channel_num,
            num_filters=FLAGS.num_filters,
            cond_size=FLAGS.cond_size,
            cond_shape=FLAGS.cond_shape,
            cond_pos=FLAGS.cond_pos,
            cond_rot=FLAGS.cond_rot)

    print("Done loading...")
if FLAGS.dataset == "imagenetfull":
|
|
# In the case of full imagenet, use custom_tensorflow dataloader
|
|
data_loader = TFImagenetLoader('train', FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale)
|
|
else:
|
|
data_loader = DataLoader(
|
|
dataset,
|
|
batch_size=FLAGS.batch_size,
|
|
num_workers=FLAGS.data_workers,
|
|
drop_last=True,
|
|
shuffle=True)
|
|
|
|
batch_size = FLAGS.batch_size
|
|
|
|
weights = [model.construct_weights('context_0')]
|
|
|
|
Y = tf.placeholder(shape=(None), dtype=tf.int32)
|
|
|
|
# Varibles to run in training
|
|
X_SPLIT = tf.split(X, FLAGS.num_gpus)
|
|
X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
|
|
LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
|
|
LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
|
|
LABEL_SPLIT_INIT = list(LABEL_SPLIT)
|
|
tower_grads = []
|
|
tower_gen_grads = []
|
|
x_mod_list = []
|
|
|
|
optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
|
|
optimizer = hvd.DistributedOptimizer(optimizer)
|
|
|
|
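
    # Per-GPU ("tower") graph construction: each tower j takes a slice of the
    # batch, computes positive-phase energies on real data, runs the Langevin
    # sampler for negative samples, and contributes gradients that are
    # averaged across towers after the loop.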
    for j in range(FLAGS.num_gpus):
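
        # With model_cclass, labels are inferred rather than given: each image
        # is scored under all 10 candidate labels and one is sampled with the
        # Gumbel-max trick (argmax of logits plus -log(-log(u)) noise).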
        if FLAGS.model_cclass:
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(
                tf.convert_to_tensor(
                    np.reshape(
                        np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                        (FLAGS.batch_size * 10, 10)),
                    dtype=tf.float32),
                trainable=False,
                dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(
                    X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(
                x_split,
                weights[0],
                label=label_tensor,
                stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(
                energy_pos_full, axis=1, keepdims=True)
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(
                -energy_pos_full - tf.log(-tf.log(uniform)) - energy_partition_est,
                axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(
                    X_SPLIT[j],
                    weights[0],
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        energy_negs.extend([model.forward(tf.stop_gradient(
            x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)
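
        # One Langevin update: inject Gaussian noise, then step downhill on
        # the energy surface, x <- clip(x - step_lr * grad_x E(x), 0, rescale).
        # The tf.while_loop below applies FLAGS.num_steps of these updates
        # (or full HMC trajectories when --hmc is set).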
        def langevin_step(counter, x_mod):
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                             mean=0.0,
                                             stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat(
                [model.forward(
                    x_mod,
                    weights[0],
                    label=LABEL_SPLIT[j],
                    reuse=True,
                    stop_at_grad=False,
                    stop_batch=True)],
                axis=0)

            x_grad, label_grad = tf.gradients(
                FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(
                        x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                else:
                    raise ValueError(
                        "Other types of projection are not supported")

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
                                    stop_at_grad=False, reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(
                tf.stop_gradient(x_mod),
                weights[0],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(
            x_mod,
            weights[0],
            reuse=True,
            label=LABEL,
            stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(label_prob *
                                       tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:
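            # All objectives push energy down on data and up on samples:
            # 'cd' is the plain contrastive-divergence difference,
            # 'logsumexp' reweights negative samples by their stop-gradient
            # Boltzmann weights, and 'softplus' saturates the difference
            # for stability.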
            if FLAGS.objective == 'logsumexp':
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(
        max_to_keep=30, keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    print("Initializing variables...")
    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    print("Start broadcast")
    sess.run(hvd.broadcast_global_variables(0))
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess,
              logger, data_loader, resume_itr,
              logdir)

    test(target_vars, saver, sess, logger, data_loader)


if __name__ == "__main__":
    main()