import tensorflow as tf
import numpy as np
from tensorflow.python.platform import flags
from data import Imagenet, Cifar10, DSprites, Mnist, TFImagenetLoader
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, MnistNet, ResNet128
import os.path as osp
import os
from baselines.logger import TensorBoardOutputFormat
from utils import average_gradients, ReplayBuffer, optimistic_restore
from tqdm import tqdm
import random
from torch.utils.data import DataLoader
from tensorflow.core.util import event_pb2
import torch
from custom_adam import AdamOptimizer
from hmc import hmc

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

import horovod.tensorflow as hvd
hvd.init()

from inception import get_inception_score

# Seed every framework with the worker rank so that each Horovod worker
# draws different data and different MCMC noise.
torch.manual_seed(hvd.rank())
np.random.seed(hvd.rank())
tf.set_random_seed(hvd.rank())

FLAGS = flags.FLAGS

# Dataset Options
flags.DEFINE_string('datasource', 'random', 'initialization for chains, either random or default (decorruption)')
flags.DEFINE_string('dataset', 'mnist', 'dsprites, cifar10, imagenet (32x32) or imagenetfull (128x128)')
flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
flags.DEFINE_bool('single', False, 'whether to debug by training on a single image')
flags.DEFINE_integer('data_workers', 4, 'Number of different data workers to load data in parallel')

# General Experiment Settings
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('log_interval', 10, 'log outputs every so many batches')
flags.DEFINE_integer('save_interval', 1000, 'save outputs every so many batches')
flags.DEFINE_integer('test_interval', 1000, 'evaluate outputs every so many batches')
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
flags.DEFINE_bool('train', True, 'whether to train or test')
flags.DEFINE_integer('epoch_num', 10000, 'Number of epochs to train on')
flags.DEFINE_float('lr', 3e-4, 'Learning rate for training')
flags.DEFINE_integer('num_gpus', 1, 'number of gpus to train on')

# EBM Specific Experiment Settings
flags.DEFINE_float('ml_coeff', 1.0, 'Maximum likelihood coefficient')
flags.DEFINE_float('l2_coeff', 1.0, 'L2 penalty on energy magnitudes during training')
flags.DEFINE_bool('cclass', False, 'whether to use conditional training in models')
flags.DEFINE_bool('model_cclass', False, 'use unsupervised clustering to infer fake labels')
flags.DEFINE_integer('temperature', 1, 'Temperature for energy function')
flags.DEFINE_string('objective', 'cd', 'use either contrastive divergence objective (least stable), '
                    'logsumexp (more stable) or softplus (most stable)')
flags.DEFINE_bool('zero_kl', False, 'whether to zero out the kl loss')

# Settings for MCMC sampling
flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images')
flags.DEFINE_string('proj_norm_type', 'li', 'Either li or l2 ball projection')
flags.DEFINE_integer('num_steps', 20, 'Steps of gradient descent for training')
flags.DEFINE_float('step_lr', 1.0, 'Size of steps for gradient descent')
flags.DEFINE_bool('replay_batch', False, 'Use MCMC chains initialized from a replay buffer.')
flags.DEFINE_bool('hmc', False, 'Whether to use HMC sampling to train models')
flags.DEFINE_float('noise_scale', 1., 'Relative amount of noise for MCMC')
flags.DEFINE_bool('pcd', False, 'whether to use pcd training instead')

# Architecture Settings
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('large_model', False, 'whether to use a large model')
flags.DEFINE_bool('larger_model', False, 'Deeper ResNet32 network')
flags.DEFINE_bool('wider_model', False, 'Wider ResNet32 network')

# Dataset settings
flags.DEFINE_bool('mixup', False, 'whether to add mixup to training images')
flags.DEFINE_bool('augment', False, 'whether to apply augmentations to images')
flags.DEFINE_float('rescale', 1.0, 'Factor to rescale inputs from 0-1 box')

# Dsprites specific experiments
flags.DEFINE_bool('cond_shape', False, 'condition on shape type')
flags.DEFINE_bool('cond_size', False, 'condition on shape size')
flags.DEFINE_bool('cond_pos', False, 'condition on position')
flags.DEFINE_bool('cond_rot', False, 'condition on rotation')
# The dsprites placeholder construction in main() also branches on these
# three flags; they were referenced but never defined, so define them here
# (defaulting to off).
flags.DEFINE_bool('dpos_only', False, 'condition only on position')
flags.DEFINE_bool('dsize_only', False, 'condition only on size')
flags.DEFINE_bool('drot_only', False, 'condition only on rotation')

FLAGS.step_lr = FLAGS.step_lr * FLAGS.rescale
FLAGS.batch_size *= FLAGS.num_gpus

print("{} batch size".format(FLAGS.batch_size))
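
# Illustration only: one transition of the sampler that `langevin_step`
# builds inside main() below, written in plain numpy. `energy_grad` stands
# in for tf.gradients of the model energy and is not defined in this file.
#
#   x <- clip(x' - step_lr * dE/dx(x'), 0, rescale),
#   where x' = x + N(0, (0.005 * rescale * noise_scale)^2)
def _langevin_update_sketch(x, energy_grad, step_lr=1.0,
                            rescale=1.0, noise_scale=1.0):
    x = x + np.random.normal(0.0, 0.005 * rescale * noise_scale, x.shape)
    x = x - step_lr * energy_grad(x)
    return np.clip(x, 0.0, rescale)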
def compress_x_mod(x_mod):
    """Quantize sampler states to uint8 before storing them in the replay buffer."""
    x_mod = (255 * np.clip(x_mod, 0, FLAGS.rescale) / FLAGS.rescale).astype(np.uint8)
    return x_mod


def decompress_x_mod(x_mod):
    """Dequantize replay-buffer entries, adding uniform noise to undo the rounding."""
    x_mod = x_mod / 256 * FLAGS.rescale + \
        np.random.uniform(0, 1 / 256 * FLAGS.rescale, x_mod.shape)
    return x_mod


def make_image(tensor):
    """Convert a numpy image into an Image protobuf for TensorBoard."""
    from PIL import Image
    if len(tensor.shape) == 4:
        _, height, width, channel = tensor.shape
    elif len(tensor.shape) == 3:
        height, width, channel = tensor.shape
    elif len(tensor.shape) == 2:
        height, width = tensor.shape
        channel = 1
    tensor = tensor.astype(np.uint8)
    image = Image.fromarray(tensor)
    import io
    output = io.BytesIO()
    image.save(output, format='PNG')
    image_string = output.getvalue()
    output.close()
    return tf.Summary.Image(height=height,
                            width=width,
                            colorspace=channel,
                            encoded_image_string=image_string)


def log_image(im, logger, tag, step=0):
    im = make_image(im)

    summary = [tf.Summary.Value(tag=tag, image=im)]
    summary = tf.Summary(value=summary)
    event = event_pb2.Event(summary=summary)
    event.step = step
    logger.writer.WriteEvent(event)
    logger.writer.Flush()


def rescale_im(image):
    """Map images from the [0, rescale] box back to uint8 pixel values."""
    image = np.clip(image, 0, FLAGS.rescale)
    if FLAGS.dataset == 'mnist' or FLAGS.dataset == 'dsprites':
        return (np.clip((FLAGS.rescale - image) * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
    else:
        return (np.clip(image * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
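
# Illustration only: chain states round-trip through the uint8 replay buffer
# as decompress_x_mod(compress_x_mod(x)); the uniform dequantization noise
# keeps stored samples off the 256 quantization levels, so the reconstruction
# differs from x by roughly one quantization level (~rescale / 256).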
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
    X = target_vars['X']
    Y = target_vars['Y']
    X_NOISE = target_vars['X_NOISE']
    train_op = target_vars['train_op']
    energy_pos = target_vars['energy_pos']
    energy_neg = target_vars['energy_neg']
    loss_energy = target_vars['loss_energy']
    loss_ml = target_vars['loss_ml']
    loss_total = target_vars['total_loss']
    gvs = target_vars['gvs']
    x_grad = target_vars['x_grad']
    x_grad_first = target_vars['x_grad_first']
    x_off = target_vars['x_off']
    temp = target_vars['temp']
    x_mod = target_vars['x_mod']
    LABEL = target_vars['LABEL']
    LABEL_POS = target_vars['LABEL_POS']
    weights = target_vars['weights']
    test_x_mod = target_vars['test_x_mod']
    eps = target_vars['eps_begin']
    label_ent = target_vars['label_ent']

    if FLAGS.use_attention:
        gamma = weights[0]['atten']['gamma']
    else:
        gamma = tf.zeros(1)

    val_output = [test_x_mod]

    gvs_dict = dict(gvs)

    log_output = [
        train_op,
        energy_pos,
        energy_neg,
        eps,
        loss_energy,
        loss_ml,
        loss_total,
        x_grad,
        x_off,
        x_mod,
        gamma,
        x_grad_first,
        label_ent,
        *gvs_dict.keys()]
    output = [train_op, x_mod]

    replay_buffer = ReplayBuffer(10000)
    itr = resume_iter
    x_mod = None
    gd_steps = 1

    dataloader_iterator = iter(dataloader)
    best_inception = 0.0

    for epoch in range(FLAGS.epoch_num):
        for data_corrupt, data, label in dataloader:
            data_corrupt = data_corrupt.numpy()
            data_corrupt_init = data_corrupt.copy()

            data = data.numpy()
            label = label.numpy()
            label_init = label.copy()

            if FLAGS.mixup:
                idx = np.random.permutation(data.shape[0])
                lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
                data = data * lam + data[idx] * (1 - lam)

            if FLAGS.replay_batch and (x_mod is not None):
                replay_buffer.add(compress_x_mod(x_mod))

                if len(replay_buffer) > FLAGS.batch_size:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    # Restart roughly 5% of the chains from fresh noise
                    # (for rescale=1).
                    replay_mask = (
                        np.random.uniform(
                            0,
                            FLAGS.rescale,
                            FLAGS.batch_size) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

            if FLAGS.pcd:
                if x_mod is not None:
                    data_corrupt = x_mod

            feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}

            if FLAGS.cclass:
                feed_dict[LABEL] = label
                feed_dict[LABEL_POS] = label_init

            if itr % FLAGS.log_interval == 0:
                (_, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad,
                 x_off, x_mod, gamma, x_grad_first, label_ent,
                 *grads) = sess.run(log_output, feed_dict)

                kvs = {}
                kvs['e_pos'] = e_pos.mean()
                kvs['e_pos_std'] = e_pos.std()
                kvs['e_neg'] = e_neg.mean()
                kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                kvs['e_neg_std'] = e_neg.std()
                kvs['temp'] = temp
                kvs['loss_e'] = loss_e.mean()
                kvs['eps'] = eps.mean()
                kvs['label_ent'] = label_ent
                kvs['loss_ml'] = loss_ml.mean()
                kvs['loss_total'] = loss_total.mean()
                kvs['x_grad'] = np.abs(x_grad).mean()
                kvs['x_grad_first'] = np.abs(x_grad_first).mean()
                kvs['x_off'] = x_off.mean()
                kvs['iter'] = itr
                kvs['gamma'] = gamma

                for grad, var_name in zip(grads, [v.name for v in gvs_dict.values()]):
                    kvs[var_name] = np.abs(grad).max()

                string = "Obtained a total of "
                for key, value in kvs.items():
                    string += "{}: {}, ".format(key, value)

                if hvd.rank() == 0:
                    print(string)
                    logger.writekvs(kvs)
            else:
                _, x_mod = sess.run(output, feed_dict)

            if itr % FLAGS.save_interval == 0 and hvd.rank() == 0:
                saver.save(
                    sess,
                    osp.join(
                        FLAGS.logdir,
                        FLAGS.exp,
                        'model_{}'.format(itr)))
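
            # Periodic qualitative + quantitative evaluation. The block below
            # re-samples chains (from the replay buffer when available,
            # otherwise from fresh uniform noise), logs corrupted/sampled/real
            # image triptychs to TensorBoard, and, for CIFAR-10/ImageNet,
            # tracks the Inception score of the current samples.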
            if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != '2d':
                try_im = x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data)

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i

                    log_image(
                        new_im, logger, 'train_gen_{}'.format(itr), step=i)

                test_im = x_mod

                try:
                    data_corrupt, data, label = next(dataloader_iterator)
                except BaseException:
                    dataloader_iterator = iter(dataloader)
                    data_corrupt, data, label = next(dataloader_iterator)

                data_corrupt = data_corrupt.numpy()

                if FLAGS.replay_batch and (
                        x_mod is not None) and len(replay_buffer) > 0:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (
                        np.random.uniform(
                            0, 1, (FLAGS.batch_size)) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.dataset == 'cifar10' or FLAGS.dataset == 'imagenet' or FLAGS.dataset == 'imagenetfull':
                    n = 128

                    if FLAGS.dataset == "imagenetfull":
                        n = 32

                    if len(replay_buffer) > n:
                        data_corrupt = decompress_x_mod(replay_buffer.sample(n))
                    elif FLAGS.dataset == 'imagenetfull':
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 128, 128, 3))
                    else:
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 32, 32, 3))

                    if FLAGS.dataset == 'cifar10':
                        label = np.eye(10)[np.random.randint(0, 10, (n))]
                    else:
                        label = np.eye(1000)[
                            np.random.randint(
                                0, 1000, (n))]

                feed_dict[X_NOISE] = data_corrupt

                feed_dict[X] = data

                if FLAGS.cclass:
                    feed_dict[LABEL] = label

                test_x_mod = sess.run(val_output, feed_dict)

                try_im = test_x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data.numpy())

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i
                    log_image(
                        new_im, logger, 'val_gen_{}'.format(itr), step=i)

                score, std = get_inception_score(list(try_im), splits=1)
                print(
                    "Inception score of {} with std of {}".format(
                        score, std))
                kvs = {}
                kvs['inception_score'] = score
                kvs['inception_score_std'] = std
                logger.writekvs(kvs)

                if score > best_inception:
                    best_inception = score
                    saver.save(
                        sess,
                        osp.join(
                            FLAGS.logdir,
                            FLAGS.exp,
                            'model_best'))

            if itr > 60000 and FLAGS.dataset == "mnist":
                # MNIST runs are hard-stopped after 60k iterations.
                assert False
            itr += 1

    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))


cifar10_map = {0: 'airplane',
               1: 'automobile',
               2: 'bird',
               3: 'cat',
               4: 'deer',
               5: 'dog',
               6: 'frog',
               7: 'horse',
               8: 'ship',
               9: 'truck'}
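
# Test-time evaluation: run the sampler from corrupted inputs, log per-image
# energies before and after sampling, then estimate the Inception score over
# roughly 50k generated samples.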
def test(target_vars, saver, sess, logger, dataloader):
    X_NOISE = target_vars['X_NOISE']
    X = target_vars['X']
    Y = target_vars['Y']
    LABEL = target_vars['LABEL']
    energy_start = target_vars['energy_start']
    x_mod = target_vars['x_mod']
    x_mod = target_vars['test_x_mod']
    energy_neg = target_vars['energy_neg']

    np.random.seed(1)
    random.seed(1)

    output = [x_mod, energy_start, energy_neg]

    dataloader_iterator = iter(dataloader)
    data_corrupt, data, label = next(dataloader_iterator)
    data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()

    orig_im = try_im = data_corrupt

    if FLAGS.cclass:
        try_im, energy_orig, energy = sess.run(
            output, {X_NOISE: orig_im, Y: label[0:1], LABEL: label})
    else:
        try_im, energy_orig, energy = sess.run(
            output, {X_NOISE: orig_im, Y: label[0:1]})

    orig_im = rescale_im(orig_im)
    try_im = rescale_im(try_im)
    actual_im = rescale_im(data)

    for i, (im, energy_i, t_im, energy_new, label_i, actual_im_i) in enumerate(
            zip(orig_im, energy_orig, try_im, energy, label, actual_im)):
        label_i = np.array(label_i)

        shape = im.shape
        new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
        size = shape[1]
        new_im[:, :size] = im
        new_im[:, size:2 * size] = t_im
        new_im[:, 2 * size:] = actual_im_i

        if FLAGS.cclass:
            label_i = np.where(label_i == 1)[0][0]
            if FLAGS.dataset == 'cifar10':
                log_image(new_im, logger, '{}_{:.4f}_now_{:.4f}_{}'.format(
                    i, energy_i[0], energy_new[0], cifar10_map[label_i]), step=i)
            else:
                log_image(
                    new_im, logger, '{}_{:.4f}_now_{:.4f}_{}'.format(
                        i, energy_i[0], energy_new[0], label_i), step=i)
        else:
            log_image(
                new_im, logger, '{}_{:.4f}_now_{:.4f}'.format(
                    i, energy_i[0], energy_new[0]), step=i)

    test_ims = list(try_im)
    real_ims = list(actual_im)

    for i in tqdm(range(50000 // FLAGS.batch_size + 1)):
        try:
            data_corrupt, data, label = next(dataloader_iterator)
        except BaseException:
            dataloader_iterator = iter(dataloader)
            data_corrupt, data, label = next(dataloader_iterator)

        data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()

        if FLAGS.cclass:
            try_im, energy_orig, energy = sess.run(
                output, {X_NOISE: data_corrupt, Y: label[0:1], LABEL: label})
        else:
            try_im, energy_orig, energy = sess.run(
                output, {X_NOISE: data_corrupt, Y: label[0:1]})

        try_im = rescale_im(try_im)
        real_im = rescale_im(data)

        test_ims.extend(list(try_im))
        real_ims.extend(list(real_im))

    score, std = get_inception_score(test_ims)
    print("Inception score of {} with std of {}".format(score, std))
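
# Illustration only: a numpy sketch of the three objectives selected by
# --objective and assembled inside main() below (the FLAGS.ml_coeff factor
# is omitted here). `e_pos` / `e_neg` are vectors of per-example energies
# for data and MCMC samples; all names are local to this sketch.
def _objective_sketch(e_pos, e_neg, temp=1.0, objective='cd'):
    if objective == 'cd':
        # Contrastive divergence: push data energies down, sample energies up.
        return np.mean(temp * e_pos) - np.mean(temp * e_neg)
    elif objective == 'logsumexp':
        # Negative energies reweighted by self-normalized importance weights
        # exp(-temp * E), treated as constants (tf.stop_gradient in the graph).
        coeff = np.exp(-temp * (e_neg - e_neg.min()))
        coeff = coeff / (coeff.sum() + 1e-4)
        return np.mean(temp * e_pos) - np.sum(coeff * temp * e_neg)
    elif objective == 'softplus':
        # Softplus of the per-example energy gap; the most stable variant.
        return np.mean(np.logaddexp(0.0, temp * (e_pos - e_neg)))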
def main():
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    LABEL = None
    print("Loading data...")
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(
                num_channels=channel_num,
                num_filters=128,
                train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(
                num_channels=channel_num,
                num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(
                num_channels=channel_num,
                num_filters=192)
        else:
            model = ResNet32(
                num_channels=channel_num,
                num_filters=128)

    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(
            num_channels=channel_num,
            num_filters=256)

    elif FLAGS.dataset == 'imagenetfull':
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(
            num_channels=channel_num,
            num_filters=64)

    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(
            num_channels=channel_num,
            num_filters=FLAGS.num_filters)

    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(
            cond_shape=FLAGS.cond_shape,
            cond_size=FLAGS.cond_size,
            cond_pos=FLAGS.cond_pos,
            cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(
            num_channels=channel_num,
            num_filters=FLAGS.num_filters,
            cond_size=FLAGS.cond_size,
            cond_shape=FLAGS.cond_shape,
            cond_pos=FLAGS.cond_pos,
            cond_rot=FLAGS.cond_rot)

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full imagenet, use a custom tensorflow dataloader.
        data_loader = TFImagenetLoader(
            'train',
            FLAGS.batch_size,
            hvd.rank(),
            hvd.size(),
            rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)

    batch_size = FLAGS.batch_size

    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Variables to run in training
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(
                tf.convert_to_tensor(
                    np.reshape(
                        np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                        (FLAGS.batch_size * 10, 10)),
                    dtype=tf.float32),
                trainable=False,
                dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(
                    X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)),
                (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(
                x_split,
                weights[0],
                label=label_tensor,
                stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(
                energy_pos_full, axis=1, keepdims=True)
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            # Gumbel-max sampling of a pseudo-label per image from the
            # energies over all ten candidate labels.
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) - energy_partition_est, axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(
                    X_SPLIT[j],
                    weights[0],
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        energy_negs.extend([model.forward(tf.stop_gradient(
            x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)
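
        # FLAGS.num_steps Langevin updates run inside a tf.while_loop below.
        # Negative samples enter the loss through tf.stop_gradient, so the
        # training gradient never backpropagates through the MCMC chain.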
        def langevin_step(counter, x_mod):
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat(
                [model.forward(
                    x_mod,
                    weights[0],
                    label=LABEL_SPLIT[j],
                    reuse=True,
                    stop_at_grad=False,
                    stop_batch=True)],
                axis=0)

            x_grad, label_grad = tf.gradients(
                FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(
                        x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - lr * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
                                    stop_at_grad=False, reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(
                tf.stop_gradient(x_mod),
                weights[0],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(
            x_mod,
            weights[0],
            reuse=True,
            label=LABEL,
            stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(label_prob *
                                       tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:
            if FLAGS.objective == 'logsumexp':
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            # L2 regularization on energy magnitudes keeps the positive and
            # negative phases from drifting apart.
            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) +
                                  tf.reduce_mean(tf.square(energy_neg)))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin
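
    # With num_gpus=1, average_gradients is a pass-through; across workers,
    # hvd.DistributedOptimizer all-reduces gradients inside compute_gradients.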
    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(
        max_to_keep=30, keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    # Make sure all workers start from rank 0's (possibly restored) weights.
    print("Initializing variables...")
    print("Start broadcast")
    sess.run(hvd.broadcast_global_variables(0))
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir)

    test(target_vars, saver, sess, logger, data_loader)


if __name__ == "__main__":
    main()
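
# Example invocations (flag values here are illustrative, not prescriptive):
#
#   mpiexec -n 1 python train.py --dataset=cifar10 --exp=cifar10_run \
#       --replay_batch --num_steps=60 --batch_size=128
#
#   python train.py --dataset=mnist --exp=mnist_run --train=False \
#       --resume_iter=10000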