import os import os.path as osp import random import horovod.tensorflow as hvd import numpy as np import tensorflow as tf import torch from baselines.logger import TensorBoardOutputFormat from custom_adam import AdamOptimizer from data import Cifar10, DSprites, Imagenet, Mnist, TFImagenetLoader from hmc import hmc from inception import get_inception_score from models import ( DspritesNet, MnistNet, ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128, ) from mpi4py import MPI from tensorflow.core.util import event_pb2 from tensorflow.python.platform import flags from torch.utils.data import DataLoader from tqdm import tqdm from utils import ReplayBuffer, average_gradients, optimistic_restore comm = MPI.COMM_WORLD rank = comm.Get_rank() hvd.init() torch.manual_seed(hvd.rank()) np.random.seed(hvd.rank()) tf.set_random_seed(hvd.rank()) FLAGS = flags.FLAGS # Dataset Options flags.DEFINE_string( "datasource", "random", "initialization for chains, either random or default (decorruption)", ) flags.DEFINE_string( "dataset", "mnist", "dsprites, cifar10, imagenet (32x32) or imagenetfull (128x128)" ) flags.DEFINE_integer("batch_size", 256, "Size of inputs") flags.DEFINE_bool("single", False, "whether to debug by training on a single image") flags.DEFINE_integer( "data_workers", 4, "Number of different data workers to load data in parallel" ) # General Experiment Settings flags.DEFINE_string( "logdir", "cachedir", "location where log of experiments will be stored" ) flags.DEFINE_string("exp", "default", "name of experiments") flags.DEFINE_integer("log_interval", 10, "log outputs every so many batches") flags.DEFINE_integer("save_interval", 1000, "save outputs every so many batches") flags.DEFINE_integer("test_interval", 1000, "evaluate outputs every so many batches") flags.DEFINE_integer("resume_iter", -1, "iteration to resume training from") flags.DEFINE_bool("train", True, "whether to train or test") flags.DEFINE_integer("epoch_num", 10000, "Number of Epochs to train on") flags.DEFINE_float("lr", 3e-4, "Learning for training") flags.DEFINE_integer("num_gpus", 1, "number of gpus to train on") # EBM Specific Experiments Settings flags.DEFINE_float("ml_coeff", 1.0, "Maximum Likelihood Coefficients") flags.DEFINE_float("l2_coeff", 1.0, "L2 Penalty training") flags.DEFINE_bool("cclass", False, "Whether to conditional training in models") flags.DEFINE_bool( "model_cclass", False, "use unsupervised clustering to infer fake labels" ) flags.DEFINE_integer("temperature", 1, "Temperature for energy function") flags.DEFINE_string( "objective", "cd", "use either contrastive divergence objective(least stable)," "logsumexp(more stable)" "softplus(most stable)", ) flags.DEFINE_bool("zero_kl", False, "whether to zero out the kl loss") # Setting for MCMC sampling flags.DEFINE_float("proj_norm", 0.0, "Maximum change of input images") flags.DEFINE_string("proj_norm_type", "li", "Either li or l2 ball projection") flags.DEFINE_integer("num_steps", 20, "Steps of gradient descent for training") flags.DEFINE_float("step_lr", 1.0, "Size of steps for gradient descent") flags.DEFINE_bool( "replay_batch", False, "Use MCMC chains initialized from a replay buffer." ) flags.DEFINE_bool("hmc", False, "Whether to use HMC sampling to train models") flags.DEFINE_float("noise_scale", 1.0, "Relative amount of noise for MCMC") flags.DEFINE_bool("pcd", False, "whether to use pcd training instead") # Architecture Settings flags.DEFINE_integer("num_filters", 64, "number of filters for conv nets") flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights") flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network") flags.DEFINE_bool("large_model", False, "whether to use a large model") flags.DEFINE_bool("larger_model", False, "Deeper ResNet32 Network") flags.DEFINE_bool("wider_model", False, "Wider ResNet32 Network") # Dataset settings flags.DEFINE_bool("mixup", False, "whether to add mixup to training images") flags.DEFINE_bool("augment", False, "whether to augmentations to images") flags.DEFINE_float("rescale", 1.0, "Factor to rescale inputs from 0-1 box") # Dsprites specific experiments flags.DEFINE_bool("cond_shape", False, "condition of shape type") flags.DEFINE_bool("cond_size", False, "condition of shape size") flags.DEFINE_bool("cond_pos", False, "condition of position loc") flags.DEFINE_bool("cond_rot", False, "condition of rot") FLAGS.step_lr = FLAGS.step_lr * FLAGS.rescale FLAGS.batch_size *= FLAGS.num_gpus print("{} batch size".format(FLAGS.batch_size)) def compress_x_mod(x_mod): x_mod = (255 * np.clip(x_mod, 0, FLAGS.rescale) / FLAGS.rescale).astype(np.uint8) return x_mod def decompress_x_mod(x_mod): x_mod = x_mod / 256 * FLAGS.rescale + np.random.uniform( 0, 1 / 256 * FLAGS.rescale, x_mod.shape ) return x_mod def make_image(tensor): """Convert an numpy representation image to Image protobuf""" from PIL import Image if len(tensor.shape) == 4: _, height, width, channel = tensor.shape elif len(tensor.shape) == 3: height, width, channel = tensor.shape elif len(tensor.shape) == 2: height, width = tensor.shape channel = 1 tensor = tensor.astype(np.uint8) image = Image.fromarray(tensor) import io output = io.BytesIO() image.save(output, format="PNG") image_string = output.getvalue() output.close() return tf.Summary.Image( height=height, width=width, colorspace=channel, encoded_image_string=image_string, ) def log_image(im, logger, tag, step=0): im = make_image(im) summary = [tf.Summary.Value(tag=tag, image=im)] summary = tf.Summary(value=summary) event = event_pb2.Event(summary=summary) event.step = step logger.writer.WriteEvent(event) logger.writer.Flush() def rescale_im(image): image = np.clip(image, 0, FLAGS.rescale) if FLAGS.dataset == "mnist" or FLAGS.dataset == "dsprites": return (np.clip((FLAGS.rescale - image) * 256 / FLAGS.rescale, 0, 255)).astype( np.uint8 ) else: return (np.clip(image * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8) def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir): X = target_vars["X"] Y = target_vars["Y"] X_NOISE = target_vars["X_NOISE"] train_op = target_vars["train_op"] energy_pos = target_vars["energy_pos"] energy_neg = target_vars["energy_neg"] loss_energy = target_vars["loss_energy"] loss_ml = target_vars["loss_ml"] loss_total = target_vars["total_loss"] gvs = target_vars["gvs"] x_grad = target_vars["x_grad"] x_grad_first = target_vars["x_grad_first"] x_off = target_vars["x_off"] temp = target_vars["temp"] x_mod = target_vars["x_mod"] LABEL = target_vars["LABEL"] LABEL_POS = target_vars["LABEL_POS"] weights = target_vars["weights"] test_x_mod = target_vars["test_x_mod"] eps = target_vars["eps_begin"] label_ent = target_vars["label_ent"] if FLAGS.use_attention: gamma = weights[0]["atten"]["gamma"] else: gamma = tf.zeros(1) val_output = [test_x_mod] gvs_dict = dict(gvs) log_output = [ train_op, energy_pos, energy_neg, eps, loss_energy, loss_ml, loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent, *gvs_dict.keys(), ] output = [train_op, x_mod] replay_buffer = ReplayBuffer(10000) itr = resume_iter x_mod = None dataloader_iterator = iter(dataloader) best_inception = 0.0 for epoch in range(FLAGS.epoch_num): for data_corrupt, data, label in dataloader: data_corrupt = data_corrupt_init = data_corrupt.numpy() data_corrupt.copy() data = data.numpy() label = label.numpy() label_init = label.copy() if FLAGS.mixup: idx = np.random.permutation(data.shape[0]) lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1)) data = data * lam + data[idx] * (1 - lam) if FLAGS.replay_batch and (x_mod is not None): replay_buffer.add(compress_x_mod(x_mod)) if len(replay_buffer) > FLAGS.batch_size: replay_batch = replay_buffer.sample(FLAGS.batch_size) replay_batch = decompress_x_mod(replay_batch) replay_mask = ( np.random.uniform(0, FLAGS.rescale, FLAGS.batch_size) > 0.05 ) data_corrupt[replay_mask] = replay_batch[replay_mask] if FLAGS.pcd: if x_mod is not None: data_corrupt = x_mod feed_dict = {X_NOISE: data_corrupt, X: data, Y: label} if FLAGS.cclass: feed_dict[LABEL] = label feed_dict[LABEL_POS] = label_init if itr % FLAGS.log_interval == 0: ( _, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent, *grads, ) = sess.run(log_output, feed_dict) kvs = {} kvs["e_pos"] = e_pos.mean() kvs["e_pos_std"] = e_pos.std() kvs["e_neg"] = e_neg.mean() kvs["e_diff"] = kvs["e_pos"] - kvs["e_neg"] kvs["e_neg_std"] = e_neg.std() kvs["temp"] = temp kvs["loss_e"] = loss_e.mean() kvs["eps"] = eps.mean() kvs["label_ent"] = label_ent kvs["loss_ml"] = loss_ml.mean() kvs["loss_total"] = loss_total.mean() kvs["x_grad"] = np.abs(x_grad).mean() kvs["x_grad_first"] = np.abs(x_grad_first).mean() kvs["x_off"] = x_off.mean() kvs["iter"] = itr kvs["gamma"] = gamma for v, k in zip(grads, [v.name for v in gvs_dict.values()]): kvs[k] = np.abs(v).max() string = "Obtained a total of " for key, value in kvs.items(): string += "{}: {}, ".format(key, value) if hvd.rank() == 0: print(string) logger.writekvs(kvs) else: _, x_mod = sess.run(output, feed_dict) if itr % FLAGS.save_interval == 0 and hvd.rank() == 0: saver.save( sess, osp.join(FLAGS.logdir, FLAGS.exp, "model_{}".format(itr)) ) if ( itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != "2d" ): try_im = x_mod orig_im = data_corrupt.squeeze() actual_im = rescale_im(data) orig_im = rescale_im(orig_im) try_im = rescale_im(try_im).squeeze() for i, (im, t_im, actual_im_i) in enumerate( zip(orig_im[:20], try_im[:20], actual_im) ): shape = orig_im.shape[1:] new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:])) size = shape[1] new_im[:, :size] = im new_im[:, size : 2 * size] = t_im new_im[:, 2 * size :] = actual_im_i log_image(new_im, logger, "train_gen_{}".format(itr), step=i) try: data_corrupt, data, label = next(dataloader_iterator) except BaseException: dataloader_iterator = iter(dataloader) data_corrupt, data, label = next(dataloader_iterator) data_corrupt = data_corrupt.numpy() if ( FLAGS.replay_batch and (x_mod is not None) and len(replay_buffer) > 0 ): replay_batch = replay_buffer.sample(FLAGS.batch_size) replay_batch = decompress_x_mod(replay_batch) replay_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.05 data_corrupt[replay_mask] = replay_batch[replay_mask] if ( FLAGS.dataset == "cifar10" or FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull" ): n = 128 if FLAGS.dataset == "imagenetfull": n = 32 if len(replay_buffer) > n: data_corrupt = decompress_x_mod(replay_buffer.sample(n)) elif FLAGS.dataset == "imagenetfull": data_corrupt = np.random.uniform( 0, FLAGS.rescale, (n, 128, 128, 3) ) else: data_corrupt = np.random.uniform( 0, FLAGS.rescale, (n, 32, 32, 3) ) if FLAGS.dataset == "cifar10": label = np.eye(10)[np.random.randint(0, 10, (n))] else: label = np.eye(1000)[np.random.randint(0, 1000, (n))] feed_dict[X_NOISE] = data_corrupt feed_dict[X] = data if FLAGS.cclass: feed_dict[LABEL] = label test_x_mod = sess.run(val_output, feed_dict) try_im = test_x_mod orig_im = data_corrupt.squeeze() actual_im = rescale_im(data.numpy()) orig_im = rescale_im(orig_im) try_im = rescale_im(try_im).squeeze() for i, (im, t_im, actual_im_i) in enumerate( zip(orig_im[:20], try_im[:20], actual_im) ): shape = orig_im.shape[1:] new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:])) size = shape[1] new_im[:, :size] = im new_im[:, size : 2 * size] = t_im new_im[:, 2 * size :] = actual_im_i log_image(new_im, logger, "val_gen_{}".format(itr), step=i) score, std = get_inception_score(list(try_im), splits=1) print("Inception score of {} with std of {}".format(score, std)) kvs = {} kvs["inception_score"] = score kvs["inception_score_std"] = std logger.writekvs(kvs) if score > best_inception: best_inception = score saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, "model_best")) if itr > 60000 and FLAGS.dataset == "mnist": assert False itr += 1 saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, "model_{}".format(itr))) cifar10_map = { 0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck", } def test(target_vars, saver, sess, logger, dataloader): X_NOISE = target_vars["X_NOISE"] target_vars["X"] Y = target_vars["Y"] LABEL = target_vars["LABEL"] energy_start = target_vars["energy_start"] x_mod = target_vars["x_mod"] x_mod = target_vars["test_x_mod"] energy_neg = target_vars["energy_neg"] np.random.seed(1) random.seed(1) output = [x_mod, energy_start, energy_neg] dataloader_iterator = iter(dataloader) data_corrupt, data, label = next(dataloader_iterator) data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy() orig_im = try_im = data_corrupt if FLAGS.cclass: try_im, energy_orig, energy = sess.run( output, {X_NOISE: orig_im, Y: label[0:1], LABEL: label} ) else: try_im, energy_orig, energy = sess.run( output, {X_NOISE: orig_im, Y: label[0:1]} ) orig_im = rescale_im(orig_im) try_im = rescale_im(try_im) actual_im = rescale_im(data) for i, (im, energy_i, t_im, energy, label_i, actual_im_i) in enumerate( zip(orig_im, energy_orig, try_im, energy, label, actual_im) ): label_i = np.array(label_i) shape = im.shape[1:] new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:])) size = shape[1] new_im[:, :size] = im new_im[:, size : 2 * size] = t_im if FLAGS.cclass: label_i = np.where(label_i == 1)[0][0] if FLAGS.dataset == "cifar10": log_image( new_im, logger, "{}_{:.4f}_now_{:.4f}_{}".format( i, energy_i[0], energy[0], cifar10_map[label_i] ), step=i, ) else: log_image( new_im, logger, "{}_{:.4f}_now_{:.4f}_{}".format( i, energy_i[0], energy[0], label_i ), step=i, ) else: log_image( new_im, logger, "{}_{:.4f}_now_{:.4f}".format(i, energy_i[0], energy[0]), step=i, ) test_ims = list(try_im) real_ims = list(actual_im) for i in tqdm(range(50000 // FLAGS.batch_size + 1)): try: data_corrupt, data, label = dataloader_iterator.next() except BaseException: dataloader_iterator = iter(dataloader) data_corrupt, data, label = dataloader_iterator.next() data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy() if FLAGS.cclass: try_im, energy_orig, energy = sess.run( output, {X_NOISE: data_corrupt, Y: label[0:1], LABEL: label} ) else: try_im, energy_orig, energy = sess.run( output, {X_NOISE: data_corrupt, Y: label[0:1]} ) try_im = rescale_im(try_im) real_im = rescale_im(data) test_ims.extend(list(try_im)) real_ims.extend(list(real_im)) score, std = get_inception_score(test_ims) print("Inception score of {} with std of {}".format(score, std)) def main(): print("Local rank: ", hvd.local_rank(), hvd.size()) logdir = osp.join(FLAGS.logdir, FLAGS.exp) if hvd.rank() == 0: if not osp.exists(logdir): os.makedirs(logdir) logger = TensorBoardOutputFormat(logdir) else: logger = None LABEL = None print("Loading data...") if FLAGS.dataset == "cifar10": dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale) test_dataset = Cifar10(train=False, rescale=FLAGS.rescale) channel_num = 3 X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32) if FLAGS.large_model: model = ResNet32Large(num_channels=channel_num, num_filters=128, train=True) elif FLAGS.larger_model: model = ResNet32Larger(num_channels=channel_num, num_filters=128) elif FLAGS.wider_model: model = ResNet32Wider(num_channels=channel_num, num_filters=192) else: model = ResNet32(num_channels=channel_num, num_filters=128) elif FLAGS.dataset == "imagenet": dataset = Imagenet(train=True) test_dataset = Imagenet(train=False) channel_num = 3 X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32) model = ResNet32Wider(num_channels=channel_num, num_filters=256) elif FLAGS.dataset == "imagenetfull": channel_num = 3 X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32) X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32) model = ResNet128(num_channels=channel_num, num_filters=64) elif FLAGS.dataset == "mnist": dataset = Mnist(rescale=FLAGS.rescale) channel_num = 1 X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32) X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32) model = MnistNet(num_channels=channel_num, num_filters=FLAGS.num_filters) elif FLAGS.dataset == "dsprites": dataset = DSprites( cond_shape=FLAGS.cond_shape, cond_size=FLAGS.cond_size, cond_pos=FLAGS.cond_pos, cond_rot=FLAGS.cond_rot, ) channel_num = 1 X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) if FLAGS.dpos_only: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) elif FLAGS.dsize_only: LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32) elif FLAGS.drot_only: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) elif FLAGS.cond_size: LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32) elif FLAGS.cond_shape: LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32) elif FLAGS.cond_pos: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) elif FLAGS.cond_rot: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) else: LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32) model = DspritesNet( num_channels=channel_num, num_filters=FLAGS.num_filters, cond_size=FLAGS.cond_size, cond_shape=FLAGS.cond_shape, cond_pos=FLAGS.cond_pos, cond_rot=FLAGS.cond_rot, ) print("Done loading...") if FLAGS.dataset == "imagenetfull": # In the case of full imagenet, use custom_tensorflow dataloader data_loader = TFImagenetLoader( "train", FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale ) else: data_loader = DataLoader( dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=True, shuffle=True, ) FLAGS.batch_size weights = [model.construct_weights("context_0")] Y = tf.placeholder(shape=(None), dtype=tf.int32) # Varibles to run in training X_SPLIT = tf.split(X, FLAGS.num_gpus) X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus) LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus) LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus) list(LABEL_SPLIT) tower_grads = [] optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999) optimizer = hvd.DistributedOptimizer(optimizer) for j in range(FLAGS.num_gpus): if FLAGS.model_cclass: ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus label_tensor = tf.Variable( tf.convert_to_tensor( np.reshape( np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)), (FLAGS.batch_size * 10, 10), ), dtype=tf.float32, ), trainable=False, dtype=tf.float32, ) x_split = tf.tile( tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1) ) x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3)) energy_pos = model.forward( x_split, weights[0], label=label_tensor, stop_at_grad=False ) energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10)) energy_partition_est = tf.reduce_logsumexp( energy_pos_full, axis=1, keepdims=True ) uniform = tf.random_uniform(tf.shape(energy_pos_full)) label_tensor = tf.argmax( -energy_pos_full - tf.log(-tf.log(uniform)) - energy_partition_est, axis=1, ) label = tf.one_hot(label_tensor, 10, dtype=tf.float32) label = tf.Print(label, [label_tensor, energy_pos_full]) LABEL_SPLIT[j] = label energy_pos = tf.concat(energy_pos, axis=0) else: energy_pos = [ model.forward( X_SPLIT[j], weights[0], label=LABEL_POS_SPLIT[j], stop_at_grad=False ) ] energy_pos = tf.concat(energy_pos, axis=0) print("Building graph...") x_mod = x_orig = X_NOISE_SPLIT[j] x_grads = [] energy_negs = [] energy_negs.extend( [ model.forward( tf.stop_gradient(x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True, ) ] ) eps_begin = tf.zeros(1) steps = tf.constant(0) def c(i, x): return tf.less(i, FLAGS.num_steps) def langevin_step(counter, x_mod): x_mod = x_mod + tf.random_normal( tf.shape(x_mod), mean=0.0, stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale, ) energy_noise = energy_start = tf.concat( [ model.forward( x_mod, weights[0], label=LABEL_SPLIT[j], reuse=True, stop_at_grad=False, stop_batch=True, ) ], axis=0, ) x_grad, label_grad = tf.gradients( FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]] ) lr = FLAGS.step_lr if FLAGS.proj_norm != 0.0: if FLAGS.proj_norm_type == "l2": x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm) elif FLAGS.proj_norm_type == "li": x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm) else: print("Other types of projection are not supported!!!") assert False # Clip gradient norm for now if FLAGS.hmc: # Step size should be tuned to get around 65% acceptance def energy(x): return FLAGS.temperature * model.forward( x, weights[0], label=LABEL_SPLIT[j], reuse=True ) x_last = hmc(x_mod, 15.0, 10, energy) else: x_last = x_mod - (lr) * x_grad x_mod = x_last x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale) counter = counter + 1 return counter, x_mod steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod)) energy_eval = model.forward( x_mod, weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True ) x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0] x_grads.append(x_grad) energy_negs.append( model.forward( tf.stop_gradient(x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True, ) ) test_x_mod = x_mod temp = FLAGS.temperature energy_neg = energy_negs[-1] x_off = tf.reduce_mean(tf.abs(x_mod[: tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j])) loss_energy = model.forward( x_mod, weights[0], reuse=True, label=LABEL, stop_grad=True ) print("Finished processing loop construction ...") target_vars = {} if FLAGS.cclass or FLAGS.model_cclass: label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0) label_prob = label_sum / tf.reduce_sum(label_sum) label_ent = -tf.reduce_sum(label_prob * tf.math.log(label_prob + 1e-7)) else: label_ent = tf.zeros(1) target_vars["label_ent"] = label_ent if FLAGS.train: if FLAGS.objective == "logsumexp": temp * energy_pos energy_neg_reduced = energy_neg - tf.reduce_min(energy_neg) coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced)) norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4 pos_loss = tf.reduce_mean(temp * energy_pos) neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss)) elif FLAGS.objective == "cd": pos_loss = tf.reduce_mean(temp * energy_pos) neg_loss = -tf.reduce_mean(temp * energy_neg) loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss)) elif FLAGS.objective == "softplus": loss_ml = FLAGS.ml_coeff * tf.nn.softplus( temp * (energy_pos - energy_neg) ) loss_total = tf.reduce_mean(loss_ml) if not FLAGS.zero_kl: loss_total = loss_total + tf.reduce_mean(loss_energy) loss_total = loss_total + FLAGS.l2_coeff * ( tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))) ) print("Started gradient computation...") gvs = optimizer.compute_gradients(loss_total) gvs = [(k, v) for (k, v) in gvs if k is not None] print("Applying gradients...") tower_grads.append(gvs) print("Finished applying gradients.") target_vars["loss_ml"] = loss_ml target_vars["total_loss"] = loss_total target_vars["loss_energy"] = loss_energy target_vars["weights"] = weights target_vars["gvs"] = gvs target_vars["X"] = X target_vars["Y"] = Y target_vars["LABEL"] = LABEL target_vars["LABEL_POS"] = LABEL_POS target_vars["X_NOISE"] = X_NOISE target_vars["energy_pos"] = energy_pos target_vars["energy_start"] = energy_negs[0] if len(x_grads) >= 1: target_vars["x_grad"] = x_grads[-1] target_vars["x_grad_first"] = x_grads[0] else: target_vars["x_grad"] = tf.zeros(1) target_vars["x_grad_first"] = tf.zeros(1) target_vars["x_mod"] = x_mod target_vars["x_off"] = x_off target_vars["temp"] = temp target_vars["energy_neg"] = energy_neg target_vars["test_x_mod"] = test_x_mod target_vars["eps_begin"] = eps_begin if FLAGS.train: grads = average_gradients(tower_grads) train_op = optimizer.apply_gradients(grads) target_vars["train_op"] = train_op config = tf.ConfigProto() if hvd.size() > 1: config.gpu_options.visible_device_list = str(hvd.local_rank()) sess = tf.Session(config=config) saver = loader = tf.train.Saver(max_to_keep=30, keep_checkpoint_every_n_hours=6) total_parameters = 0 for variable in tf.trainable_variables(): # shape is an array of tf.Dimension shape = variable.get_shape() variable_parameters = 1 for dim in shape: variable_parameters *= dim.value total_parameters += variable_parameters print("Model has a total of {} parameters".format(total_parameters)) sess.run(tf.global_variables_initializer()) resume_itr = 0 if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0: model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter)) resume_itr = FLAGS.resume_iter # saver.restore(sess, model_file) optimistic_restore(sess, model_file) sess.run(hvd.broadcast_global_variables(0)) print("Initializing variables...") print("Start broadcast") print("End broadcast") if FLAGS.train: train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir) test(target_vars, saver, sess, logger, data_loader) if __name__ == "__main__": main()