import tensorflow as tf
import numpy as np
from tensorflow.python.platform import flags
from data import Imagenet, Cifar10, DSprites, Mnist, TFImagenetLoader
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, MnistNet, ResNet128
import os.path as osp
import os
from baselines.logger import TensorBoardOutputFormat
from utils import average_gradients, ReplayBuffer, optimistic_restore
from tqdm import tqdm
import random
from torch.utils.data import DataLoader
import time
from io import StringIO
from tensorflow.core.util import event_pb2
import torch
from custom_adam import AdamOptimizer
from scipy.misc import imsave  # note: removed in scipy >= 1.2, so this requires an older scipy
import matplotlib.pyplot as plt
from hmc import hmc
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
import horovod.tensorflow as hvd
hvd.init()
from inception import get_inception_score
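# Seed torch, numpy and TF with the Horovod rank so every worker draws
# different MCMC noise and data shuffles while runs stay reproducible.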
torch.manual_seed(hvd.rank())
np.random.seed(hvd.rank())
tf.set_random_seed(hvd.rank())
FLAGS = flags.FLAGS
# Dataset Options
flags.DEFINE_string('datasource', 'random',
'initialization for chains, either random or default (decorruption)')
flags.DEFINE_string('dataset', 'mnist',
    'mnist, dsprites, cifar10, imagenet (32x32) or imagenetfull (128x128)')
flags.DEFINE_integer('batch_size', 256, 'Number of inputs per batch')
flags.DEFINE_bool('single', False, 'whether to debug by training on a single image')
flags.DEFINE_integer('data_workers', 4,
'Number of different data workers to load data in parallel')
# General Experiment Settings
flags.DEFINE_string('logdir', 'cachedir',
'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('log_interval', 10, 'log outputs every so many batches')
flags.DEFINE_integer('save_interval', 1000,'save outputs every so many batches')
flags.DEFINE_integer('test_interval', 1000,'evaluate outputs every so many batches')
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
flags.DEFINE_bool('train', True, 'whether to train or test')
flags.DEFINE_integer('epoch_num', 10000, 'Number of Epochs to train on')
flags.DEFINE_float('lr', 3e-4, 'Learning rate for training')
flags.DEFINE_integer('num_gpus', 1, 'number of gpus to train on')
# EBM Specific Experiments Settings
flags.DEFINE_float('ml_coeff', 1.0, 'Maximum Likelihood Coefficients')
flags.DEFINE_float('l2_coeff', 1.0, 'L2 Penalty training')
flags.DEFINE_bool('cclass', False, 'Whether to use conditional training in models')
flags.DEFINE_bool('model_cclass', False,'use unsupervised clustering to infer fake labels')
flags.DEFINE_integer('temperature', 1, 'Temperature for energy function')
flags.DEFINE_string('objective', 'cd', 'use either contrastive divergence objective (least stable), '
    'logsumexp (more stable), or '
    'softplus (most stable)')
flags.DEFINE_bool('zero_kl', False, 'whether to zero out the kl loss')
# Setting for MCMC sampling
flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images')
flags.DEFINE_string('proj_norm_type', 'li', 'Either li or l2 ball projection')
flags.DEFINE_integer('num_steps', 20, 'Steps of gradient descent for training')
flags.DEFINE_float('step_lr', 1.0, 'Size of steps for gradient descent')
flags.DEFINE_bool('replay_batch', False, 'Use MCMC chains initialized from a replay buffer.')
flags.DEFINE_bool('hmc', False, 'Whether to use HMC sampling to train models')
flags.DEFINE_float('noise_scale', 1.,'Relative amount of noise for MCMC')
flags.DEFINE_bool('pcd', False, 'whether to use pcd training instead')
# Architecture Settings
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('large_model', False, 'whether to use a large model')
flags.DEFINE_bool('larger_model', False, 'Deeper ResNet32 Network')
flags.DEFINE_bool('wider_model', False, 'Wider ResNet32 Network')
# Dataset settings
flags.DEFINE_bool('mixup', False, 'whether to add mixup to training images')
flags.DEFINE_bool('augment', False, 'whether to apply augmentations to images')
flags.DEFINE_float('rescale', 1.0, 'Factor to rescale inputs from 0-1 box')
# Dsprites specific experiments
flags.DEFINE_bool('cond_shape', False, 'condition on shape type')
flags.DEFINE_bool('cond_size', False, 'condition on shape size')
flags.DEFINE_bool('cond_pos', False, 'condition on shape position')
flags.DEFINE_bool('cond_rot', False, 'condition on shape rotation')
# main() also reads the three flags below in its dsprites branch, but they are
# defined nowhere else in this file; they are added here with assumed False
# defaults so that the flag lookup does not raise.
flags.DEFINE_bool('dpos_only', False, 'condition only on position')
flags.DEFINE_bool('dsize_only', False, 'condition only on size')
flags.DEFINE_bool('drot_only', False, 'condition only on rotation')
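# A hypothetical single-worker invocation using the flags defined above
# (values are illustrative, not settings taken from any reference run):
#   python train.py --dataset=cifar10 --exp=cifar10_demo --batch_size=128 \
#       --num_steps=60 --step_lr=10.0 --replay_batch=True
# Multi-GPU training would be launched through Horovod (e.g. horovodrun).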
FLAGS.step_lr = FLAGS.step_lr * FLAGS.rescale
FLAGS.batch_size *= FLAGS.num_gpus
print("{} batch size".format(FLAGS.batch_size))
def compress_x_mod(x_mod):
x_mod = (255 * np.clip(x_mod, 0, FLAGS.rescale) / FLAGS.rescale).astype(np.uint8)
return x_mod
def decompress_x_mod(x_mod):
x_mod = x_mod / 256 * FLAGS.rescale + \
np.random.uniform(0, 1 / 256 * FLAGS.rescale, x_mod.shape)
return x_mod
def make_image(tensor):
"""Convert an numpy representation image to Image protobuf"""
from PIL import Image
if len(tensor.shape) == 4:
    _, height, width, channel = tensor.shape
    # Image.fromarray cannot handle a batch dimension; keep only the
    # first image (assumed intent when a 4D tensor is passed).
    tensor = tensor[0]
elif len(tensor.shape) == 3:
height, width, channel = tensor.shape
elif len(tensor.shape) == 2:
height, width = tensor.shape
channel = 1
tensor = tensor.astype(np.uint8)
image = Image.fromarray(tensor)
import io
output = io.BytesIO()
image.save(output, format='PNG')
image_string = output.getvalue()
output.close()
return tf.Summary.Image(height=height,
width=width,
colorspace=channel,
encoded_image_string=image_string)
def log_image(im, logger, tag, step=0):
im = make_image(im)
summary = [tf.Summary.Value(tag=tag, image=im)]
summary = tf.Summary(value=summary)
event = event_pb2.Event(summary=summary)
event.step = step
logger.writer.WriteEvent(event)
logger.writer.Flush()
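# Convert a float image in [0, rescale] to uint8 for logging. MNIST and
# dsprites images are inverted so shapes render dark on a white background.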
def rescale_im(image):
image = np.clip(image, 0, FLAGS.rescale)
if FLAGS.dataset == 'mnist' or FLAGS.dataset == 'dsprites':
return (np.clip((FLAGS.rescale - image) * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
else:
return (np.clip(image * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
X = target_vars['X']
Y = target_vars['Y']
X_NOISE = target_vars['X_NOISE']
train_op = target_vars['train_op']
energy_pos = target_vars['energy_pos']
energy_neg = target_vars['energy_neg']
loss_energy = target_vars['loss_energy']
loss_ml = target_vars['loss_ml']
loss_total = target_vars['total_loss']
gvs = target_vars['gvs']
x_grad = target_vars['x_grad']
x_grad_first = target_vars['x_grad_first']
x_off = target_vars['x_off']
temp = target_vars['temp']
x_mod = target_vars['x_mod']
LABEL = target_vars['LABEL']
LABEL_POS = target_vars['LABEL_POS']
weights = target_vars['weights']
test_x_mod = target_vars['test_x_mod']
eps = target_vars['eps_begin']
label_ent = target_vars['label_ent']
if FLAGS.use_attention:
gamma = weights[0]['atten']['gamma']
else:
gamma = tf.zeros(1)
val_output = [test_x_mod]
gvs_dict = dict(gvs)
log_output = [
train_op,
energy_pos,
energy_neg,
eps,
loss_energy,
loss_ml,
loss_total,
x_grad,
x_off,
x_mod,
gamma,
x_grad_first,
label_ent,
*gvs_dict.keys()]
output = [train_op, x_mod]
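# ReplayBuffer (imported from utils) holds up to 10000 past MCMC samples;
# as used in this file its interface is add(batch), sample(n) and len().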
replay_buffer = ReplayBuffer(10000)
itr = resume_iter
x_mod = None
gd_steps = 1
dataloader_iterator = iter(dataloader)
best_inception = 0.0
for epoch in range(FLAGS.epoch_num):
for data_corrupt, data, label in dataloader:
data_corrupt = data_corrupt.numpy()
data_corrupt_init = data_corrupt.copy()
data = data.numpy()
label = label.numpy()
label_init = label.copy()
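# Mixup (if enabled) replaces each image with a convex combination
# lam * x_i + (1 - lam) * x_j, lam ~ Beta(1, 1) (i.e. uniform on [0, 1]).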
if FLAGS.mixup:
idx = np.random.permutation(data.shape[0])
lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
data = data * lam + data[idx] * (1 - lam)
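# Replay buffer: persist final MCMC samples across iterations and start
# most new chains from them, amortizing mixing time over training.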
if FLAGS.replay_batch and (x_mod is not None):
replay_buffer.add(compress_x_mod(x_mod))
if len(replay_buffer) > FLAGS.batch_size:
replay_batch = replay_buffer.sample(FLAGS.batch_size)
replay_batch = decompress_x_mod(replay_batch)
# The mask should be drawn from [0, 1) so that ~95% of chains restart
# from the buffer; the original drew from [0, FLAGS.rescale), which
# skews that fraction whenever rescale != 1 (assumed unintended; the
# second occurrence of this logic below already uses [0, 1)).
replay_mask = (
    np.random.uniform(0, 1, FLAGS.batch_size) > 0.05)
data_corrupt[replay_mask] = replay_batch[replay_mask]
if FLAGS.pcd:
if x_mod is not None:
data_corrupt = x_mod
feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}
if FLAGS.cclass:
feed_dict[LABEL] = label
feed_dict[LABEL_POS] = label_init
if itr % FLAGS.log_interval == 0:
_, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent, * \
grads = sess.run(log_output, feed_dict)
kvs = {}
kvs['e_pos'] = e_pos.mean()
kvs['e_pos_std'] = e_pos.std()
kvs['e_neg'] = e_neg.mean()
kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
kvs['e_neg_std'] = e_neg.std()
kvs['temp'] = temp
kvs['loss_e'] = loss_e.mean()
kvs['eps'] = eps.mean()
kvs['label_ent'] = label_ent
kvs['loss_ml'] = loss_ml.mean()
kvs['loss_total'] = loss_total.mean()
kvs['x_grad'] = np.abs(x_grad).mean()
kvs['x_grad_first'] = np.abs(x_grad_first).mean()
kvs['x_off'] = x_off.mean()
kvs['iter'] = itr
kvs['gamma'] = gamma
for grad, var_name in zip(grads, [v.name for v in gvs_dict.values()]):
    kvs[var_name] = np.abs(grad).max()
string = "Obtained a total of "
for key, value in kvs.items():
string += "{}: {}, ".format(key, value)
if hvd.rank() == 0:
print(string)
logger.writekvs(kvs)
else:
_, x_mod = sess.run(output, feed_dict)
if itr % FLAGS.save_interval == 0 and hvd.rank() == 0:
saver.save(
sess,
osp.join(
FLAGS.logdir,
FLAGS.exp,
'model_{}'.format(itr)))
if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != '2d':
try_im = x_mod
orig_im = data_corrupt.squeeze()
actual_im = rescale_im(data)
orig_im = rescale_im(orig_im)
try_im = rescale_im(try_im).squeeze()
for i, (im, t_im, actual_im_i) in enumerate(
zip(orig_im[:20], try_im[:20], actual_im)):
shape = orig_im.shape[1:]
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
size = shape[1]
new_im[:, :size] = im
new_im[:, size:2 * size] = t_im
new_im[:, 2 * size:] = actual_im_i
log_image(
new_im, logger, 'train_gen_{}'.format(itr), step=i)
test_im = x_mod
try:
    data_corrupt, data, label = next(dataloader_iterator)
except StopIteration:
    dataloader_iterator = iter(dataloader)
    data_corrupt, data, label = next(dataloader_iterator)
data_corrupt = data_corrupt.numpy()
if FLAGS.replay_batch and (
x_mod is not None) and len(replay_buffer) > 0:
replay_batch = replay_buffer.sample(FLAGS.batch_size)
replay_batch = decompress_x_mod(replay_batch)
replay_mask = (
np.random.uniform(
0, 1, (FLAGS.batch_size)) > 0.05)
data_corrupt[replay_mask] = replay_batch[replay_mask]
if FLAGS.dataset == 'cifar10' or FLAGS.dataset == 'imagenet' or FLAGS.dataset == 'imagenetfull':
n = 128
if FLAGS.dataset == "imagenetfull":
n = 32
if len(replay_buffer) > n:
data_corrupt = decompress_x_mod(replay_buffer.sample(n))
elif FLAGS.dataset == 'imagenetfull':
data_corrupt = np.random.uniform(
0, FLAGS.rescale, (n, 128, 128, 3))
else:
data_corrupt = np.random.uniform(
0, FLAGS.rescale, (n, 32, 32, 3))
if FLAGS.dataset == 'cifar10':
label = np.eye(10)[np.random.randint(0, 10, (n))]
else:
label = np.eye(1000)[
np.random.randint(
0, 1000, (n))]
feed_dict[X_NOISE] = data_corrupt
feed_dict[X] = data
if FLAGS.cclass:
feed_dict[LABEL] = label
test_x_mod = sess.run(val_output, feed_dict)
try_im = test_x_mod
orig_im = data_corrupt.squeeze()
actual_im = rescale_im(data.numpy())
orig_im = rescale_im(orig_im)
try_im = rescale_im(try_im).squeeze()
for i, (im, t_im, actual_im_i) in enumerate(
zip(orig_im[:20], try_im[:20], actual_im)):
shape = orig_im.shape[1:]
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
size = shape[1]
new_im[:, :size] = im
new_im[:, size:2 * size] = t_im
new_im[:, 2 * size:] = actual_im_i
log_image(
new_im, logger, 'val_gen_{}'.format(itr), step=i)
score, std = get_inception_score(list(try_im), splits=1)
print(
"Inception score of {} with std of {}".format(
score, std))
kvs = {}
kvs['inception_score'] = score
kvs['inception_score_std'] = std
logger.writekvs(kvs)
if score > best_inception:
best_inception = score
saver.save(
sess,
osp.join(
FLAGS.logdir,
FLAGS.exp,
'model_best'))
if itr > 60000 and FLAGS.dataset == "mnist":
    assert False, "Halting MNIST training after 60000 iterations"
itr += 1
saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))
cifar10_map = {0: 'airplane',
1: 'automobile',
2: 'bird',
3: 'cat',
4: 'deer',
5: 'dog',
6: 'frog',
7: 'horse',
8: 'ship',
9: 'truck'}
def test(target_vars, saver, sess, logger, dataloader):
X_NOISE = target_vars['X_NOISE']
X = target_vars['X']
Y = target_vars['Y']
LABEL = target_vars['LABEL']
energy_start = target_vars['energy_start']
x_mod = target_vars['test_x_mod']
energy_neg = target_vars['energy_neg']
np.random.seed(1)
random.seed(1)
output = [x_mod, energy_start, energy_neg]
dataloader_iterator = iter(dataloader)
data_corrupt, data, label = next(dataloader_iterator)
data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()
orig_im = try_im = data_corrupt
if FLAGS.cclass:
try_im, energy_orig, energy = sess.run(
output, {X_NOISE: orig_im, Y: label[0:1], LABEL: label})
else:
try_im, energy_orig, energy = sess.run(
output, {X_NOISE: orig_im, Y: label[0:1]})
orig_im = rescale_im(orig_im)
try_im = rescale_im(try_im)
actual_im = rescale_im(data)
for i, (im, energy_i, t_im, energy_new_i, label_i, actual_im_i) in enumerate(
        zip(orig_im, energy_orig, try_im, energy, label, actual_im)):
label_i = np.array(label_i)
shape = im.shape[1:]
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
size = shape[1]
new_im[:, :size] = im
new_im[:, size:2 * size] = t_im
if FLAGS.cclass:
label_i = np.where(label_i == 1)[0][0]
if FLAGS.dataset == 'cifar10':
log_image(new_im, logger, '{}_{:.4f}_now_{:.4f}_{}'.format(
    i, energy_i[0], energy_new_i[0], cifar10_map[label_i]), step=i)
else:
log_image(
new_im,
logger,
'{}_{:.4f}_now_{:.4f}_{}'.format(
i,
energy_i[0],
energy_new_i[0],
label_i),
step=i)
else:
log_image(
new_im,
logger,
'{}_{:.4f}_now_{:.4f}'.format(
i,
energy_i[0],
energy_new_i[0]),
step=i)
test_ims = list(try_im)
real_ims = list(actual_im)
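# Generate roughly 50000 samples so the inception score is computed on a
# sample count comparable to the standard evaluation protocol.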
for i in tqdm(range(50000 // FLAGS.batch_size + 1)):
try:
    data_corrupt, data, label = next(dataloader_iterator)
except StopIteration:
    dataloader_iterator = iter(dataloader)
    data_corrupt, data, label = next(dataloader_iterator)
data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()
if FLAGS.cclass:
try_im, energy_orig, energy = sess.run(
output, {X_NOISE: data_corrupt, Y: label[0:1], LABEL: label})
else:
try_im, energy_orig, energy = sess.run(
output, {X_NOISE: data_corrupt, Y: label[0:1]})
try_im = rescale_im(try_im)
real_im = rescale_im(data)
test_ims.extend(list(try_im))
real_ims.extend(list(real_im))
score, std = get_inception_score(test_ims)
print("Inception score of {} with std of {}".format(score, std))
def main():
print("Local rank: ", hvd.local_rank(), hvd.size())
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
if hvd.rank() == 0:
if not osp.exists(logdir):
os.makedirs(logdir)
logger = TensorBoardOutputFormat(logdir)
else:
logger = None
LABEL = None
print("Loading data...")
if FLAGS.dataset == 'cifar10':
dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
channel_num = 3
X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)
if FLAGS.large_model:
model = ResNet32Large(
num_channels=channel_num,
num_filters=128,
train=True)
elif FLAGS.larger_model:
model = ResNet32Larger(
num_channels=channel_num,
num_filters=128)
elif FLAGS.wider_model:
model = ResNet32Wider(
num_channels=channel_num,
num_filters=192)
else:
model = ResNet32(
num_channels=channel_num,
num_filters=128)
elif FLAGS.dataset == 'imagenet':
dataset = Imagenet(train=True)
test_dataset = Imagenet(train=False)
channel_num = 3
X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
model = ResNet32Wider(
num_channels=channel_num,
num_filters=256)
elif FLAGS.dataset == 'imagenetfull':
channel_num = 3
X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
model = ResNet128(
num_channels=channel_num,
num_filters=64)
elif FLAGS.dataset == 'mnist':
dataset = Mnist(rescale=FLAGS.rescale)
test_dataset = dataset
channel_num = 1
X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)
model = MnistNet(
num_channels=channel_num,
num_filters=FLAGS.num_filters)
elif FLAGS.dataset == 'dsprites':
dataset = DSprites(
cond_shape=FLAGS.cond_shape,
cond_size=FLAGS.cond_size,
cond_pos=FLAGS.cond_pos,
cond_rot=FLAGS.cond_rot)
test_dataset = dataset
channel_num = 1
X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
if FLAGS.dpos_only:
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
elif FLAGS.dsize_only:
LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
elif FLAGS.drot_only:
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
elif FLAGS.cond_size:
LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
elif FLAGS.cond_shape:
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
elif FLAGS.cond_pos:
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
elif FLAGS.cond_rot:
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
else:
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
model = DspritesNet(
num_channels=channel_num,
num_filters=FLAGS.num_filters,
cond_size=FLAGS.cond_size,
cond_shape=FLAGS.cond_shape,
cond_pos=FLAGS.cond_pos,
cond_rot=FLAGS.cond_rot)
print("Done loading...")
if FLAGS.dataset == "imagenetfull":
# For full ImageNet, use the custom TensorFlow data loader
data_loader = TFImagenetLoader('train', FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale)
else:
data_loader = DataLoader(
dataset,
batch_size=FLAGS.batch_size,
num_workers=FLAGS.data_workers,
drop_last=True,
shuffle=True)
batch_size = FLAGS.batch_size
weights = [model.construct_weights('context_0')]
Y = tf.placeholder(shape=(None), dtype=tf.int32)
# Variables to run in training
X_SPLIT = tf.split(X, FLAGS.num_gpus)
X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
LABEL_SPLIT_INIT = list(LABEL_SPLIT)
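# Build one replica ("tower") of the model per GPU; each tower's gradients
# are appended to tower_grads and averaged before being applied.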
tower_grads = []
tower_gen_grads = []
x_mod_list = []
optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
optimizer = hvd.DistributedOptimizer(optimizer)
for j in range(FLAGS.num_gpus):
if FLAGS.model_cclass:
ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
label_tensor = tf.Variable(
tf.convert_to_tensor(
np.reshape(
np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
(FLAGS.batch_size * 10, 10)),
dtype=tf.float32),
trainable=False,
dtype=tf.float32)
x_split = tf.tile(
tf.reshape(
X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1))
x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
energy_pos = model.forward(
x_split,
weights[0],
label=label_tensor,
stop_at_grad=False)
energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
energy_partition_est = tf.reduce_logsumexp(
energy_pos_full, axis=1, keepdims=True)
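# Sample a label from the model's conditional p(y | x) with the Gumbel-max
# trick: argmax_y(-E(x, y) + g_y), where g_y = -log(-log(u)) is Gumbel noise
# derived from uniform u (subtracting the logsumexp partition estimate
# does not change the argmax).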
uniform = tf.random_uniform(tf.shape(energy_pos_full))
label_tensor = tf.argmax(-energy_pos_full -
tf.log(-tf.log(uniform)) - energy_partition_est, axis=1)
label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
label = tf.Print(label, [label_tensor, energy_pos_full])
LABEL_SPLIT[j] = label
energy_pos = tf.concat(energy_pos, axis=0)
else:
energy_pos = [
model.forward(
X_SPLIT[j],
weights[0],
label=LABEL_POS_SPLIT[j],
stop_at_grad=False)]
energy_pos = tf.concat(energy_pos, axis=0)
print("Building graph...")
x_mod = x_orig = X_NOISE_SPLIT[j]
x_grads = []
energy_negs = []
loss_energys = []
energy_negs.extend([model.forward(tf.stop_gradient(
x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
eps_begin = tf.zeros(1)
steps = tf.constant(0)
c = lambda i, x: tf.less(i, FLAGS.num_steps)
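# One Langevin step: perturb x with Gaussian noise, then descend the energy,
#   x <- x + N(0, sigma^2);  x <- clip(x - step_lr * grad_x E(x), 0, rescale)
# optionally clipping the gradient (proj_norm) or using HMC instead.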
def langevin_step(counter, x_mod):
x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
mean=0.0,
stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)
energy_noise = energy_start = tf.concat(
[model.forward(
x_mod,
weights[0],
label=LABEL_SPLIT[j],
reuse=True,
stop_at_grad=False,
stop_batch=True)],
axis=0)
x_grad, label_grad = tf.gradients(
FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]])
energy_noise_old = energy_noise
lr = FLAGS.step_lr
if FLAGS.proj_norm != 0.0:
if FLAGS.proj_norm_type == 'l2':
x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
elif FLAGS.proj_norm_type == 'li':
x_grad = tf.clip_by_value(
x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
else:
    raise ValueError("Other types of projection are not supported")
# Clip gradient norm for now
if FLAGS.hmc:
# Step size should be tuned to get around 65% acceptance
def energy(x):
return FLAGS.temperature * \
model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)
x_last = hmc(x_mod, 15., 10, energy)
else:
x_last = x_mod - (lr) * x_grad
x_mod = x_last
x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)
counter = counter + 1
return counter, x_mod
steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))
energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
stop_at_grad=False, reuse=True)
x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
x_grads.append(x_grad)
energy_negs.append(
model.forward(
tf.stop_gradient(x_mod),
weights[0],
label=LABEL_SPLIT[j],
stop_at_grad=False,
reuse=True))
test_x_mod = x_mod
temp = FLAGS.temperature
energy_neg = energy_negs[-1]
x_off = tf.reduce_mean(
tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))
loss_energy = model.forward(
x_mod,
weights[0],
reuse=True,
label=LABEL,
stop_grad=True)
print("Finished processing loop construction ...")
target_vars = {}
if FLAGS.cclass or FLAGS.model_cclass:
label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
label_prob = label_sum / tf.reduce_sum(label_sum)
label_ent = -tf.reduce_sum(label_prob *
tf.math.log(label_prob + 1e-7))
else:
label_ent = tf.zeros(1)
target_vars['label_ent'] = label_ent
if FLAGS.train:
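# All objectives lower the energy of data ("positive") samples and raise the
# energy of MCMC ("negative") samples; they differ in weighting negatives:
#   cd        - mean(E(x+)) - mean(E(x-))
#   logsumexp - negatives reweighted by a softmax over their own energies
#   softplus  - softplus(E(x+) - E(x-)), a smoothed, bounded variant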
if FLAGS.objective == 'logsumexp':
energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
pos_loss = tf.reduce_mean(temp * energy_pos)
neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
elif FLAGS.objective == 'cd':
pos_loss = tf.reduce_mean(temp * energy_pos)
neg_loss = -tf.reduce_mean(temp * energy_neg)
loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
elif FLAGS.objective == 'softplus':
loss_ml = FLAGS.ml_coeff * \
tf.nn.softplus(temp * (energy_pos - energy_neg))
loss_total = tf.reduce_mean(loss_ml)
if not FLAGS.zero_kl:
loss_total = loss_total + tf.reduce_mean(loss_energy)
loss_total = loss_total + \
FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))
print("Started gradient computation...")
gvs = optimizer.compute_gradients(loss_total)
gvs = [(k, v) for (k, v) in gvs if k is not None]
print("Applying gradients...")
tower_grads.append(gvs)
print("Finished applying gradients.")
target_vars['loss_ml'] = loss_ml
target_vars['total_loss'] = loss_total
target_vars['loss_energy'] = loss_energy
target_vars['weights'] = weights
target_vars['gvs'] = gvs
target_vars['X'] = X
target_vars['Y'] = Y
target_vars['LABEL'] = LABEL
target_vars['LABEL_POS'] = LABEL_POS
target_vars['X_NOISE'] = X_NOISE
target_vars['energy_pos'] = energy_pos
target_vars['energy_start'] = energy_negs[0]
if len(x_grads) >= 1:
target_vars['x_grad'] = x_grads[-1]
target_vars['x_grad_first'] = x_grads[0]
else:
target_vars['x_grad'] = tf.zeros(1)
target_vars['x_grad_first'] = tf.zeros(1)
target_vars['x_mod'] = x_mod
target_vars['x_off'] = x_off
target_vars['temp'] = temp
target_vars['energy_neg'] = energy_neg
target_vars['test_x_mod'] = test_x_mod
target_vars['eps_begin'] = eps_begin
if FLAGS.train:
grads = average_gradients(tower_grads)
train_op = optimizer.apply_gradients(grads)
target_vars['train_op'] = train_op
config = tf.ConfigProto()
if hvd.size() > 1:
config.gpu_options.visible_device_list = str(hvd.local_rank())
sess = tf.Session(config=config)
saver = loader = tf.train.Saver(
max_to_keep=30, keep_checkpoint_every_n_hours=6)
total_parameters = 0
for variable in tf.trainable_variables():
# shape is an array of tf.Dimension
shape = variable.get_shape()
variable_parameters = 1
for dim in shape:
variable_parameters *= dim.value
total_parameters += variable_parameters
print("Model has a total of {} parameters".format(total_parameters))
sess.run(tf.global_variables_initializer())
resume_itr = 0
if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
resume_itr = FLAGS.resume_iter
# saver.restore(sess, model_file)
optimistic_restore(sess, model_file)
print("Initializing variables...")
print("Start broadcast")
# hvd.broadcast_global_variables is collective: it must run on every rank
# (not only rank 0) so non-restoring workers receive rank 0's parameters.
sess.run(hvd.broadcast_global_variables(0))
print("End broadcast")
if FLAGS.train:
train(target_vars, saver, sess,
logger, data_loader, resume_itr,
logdir)
test(target_vars, saver, sess, logger, data_loader)
if __name__ == "__main__":
main()