From ec2ff8be874435e01fda6e2c4867e25f11275ddd Mon Sep 17 00:00:00 2001 From: Steinkirch Date: Sun, 10 May 2020 22:32:26 -0700 Subject: [PATCH] Add adapted code from OpenAI ebs --- .gitignore | 5 + README.md | 124 +++++ ais.py | 249 +++++++++ custom_adam.py | 236 ++++++++ data.py | 546 ++++++++++++++++++ ebm_combine.py | 698 +++++++++++++++++++++++ ebm_sandbox.py | 981 ++++++++++++++++++++++++++++++++ fid.py | 292 ++++++++++ hmc.py | 129 +++++ imagenet_demo.py | 73 +++ imagenet_preprocessing.py | 337 +++++++++++ inception.py | 105 ++++ models.py | 622 +++++++++++++++++++++ requirements.txt | 17 + test_inception.py | 333 +++++++++++ train.py | 941 +++++++++++++++++++++++++++++++ utils.py | 1107 +++++++++++++++++++++++++++++++++++++ 17 files changed, 6795 insertions(+) create mode 100644 README.md create mode 100644 ais.py create mode 100644 custom_adam.py create mode 100644 data.py create mode 100644 ebm_combine.py create mode 100644 ebm_sandbox.py create mode 100644 fid.py create mode 100644 hmc.py create mode 100644 imagenet_demo.py create mode 100644 imagenet_preprocessing.py create mode 100644 inception.py create mode 100644 models.py create mode 100644 requirements.txt create mode 100644 test_inception.py create mode 100644 train.py create mode 100644 utils.py diff --git a/.gitignore b/.gitignore index b6e4761..ff673a1 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,8 @@ dmypy.json # Pyre type checker .pyre/ + +# Custom +sandbox_cachedir/ +cachedir +results \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..cbe2188 --- /dev/null +++ b/README.md @@ -0,0 +1,124 @@ +# Implicit Generation and Generalization in Energy Based Models + +Code for [Implicit Generation and Generalization in Energy Based Models](https://arxiv.org/pdf/1903.08689.pdf). Blog post can be found [here](https://openai.com/blog/energy-based-models/) and website with pretrained models can be found [here](https://sites.google.com/view/igebm/home). + +## Requirements + +To install the prerequisites for the project run +``` +pip install -r requirements.txt +mkdir sandbox_cachedir +``` + +Download all [pretrained models](https://sites.google.com/view/igebm/home) and unzip into the folder cachedir. + +## Download Datasets + +For MNIST and CIFAR-10 datasets, the code will directly download the data. + +For ImageNet 128x128 dataset, download the TFRecords of the Imagenet dataset by running the following command + +``` +for i in $(seq -f "%05g" 0 1023) +do + wget https://storage.googleapis.com/ebm_demo/data/imagenet/train-$i-of-01024 +done + +for i in $(seq -f "%05g" 0 127) +do + wget https://storage.googleapis.com/ebm_demo/data/imagenet/validation-$i-of-00128 +done + +wget https://storage.googleapis.com/ebm_demo/data/imagenet/index.json +``` + +For Imagenet 32x32 dataset, download the Imagenet 32x32 dataset and unzip by running the following command + +``` +wget https://storage.googleapis.com/ebm_demo/data/imagenet32/Imagenet32_train.zip +wget https://storage.googleapis.com/ebm_demo/data/imagenet32/Imagenet32_val.zip +``` + +For dSprites dataset, download the dataset by running + +``` +wget https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true +``` + +## Training + +To train on different datasets: + +For CIFAR-10 Unconditional + +``` +python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --large_model +``` + +For CIFAR-10 Conditional + +``` +python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --cclass +``` + +For ImageNet 32x32 Conditional + +``` +python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 step_lr=10.0 --proj_norm=0.01 --replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path= +``` + +For ImageNet 128x128 Conditional + +``` +python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0 --replay_batch --swish_act --cclass --zero_kl --dataset=imagenetfull --imagenet_datadir= +``` + +All code supports horovod execution, so model training can be increased substantially by using multiple different workers by running each command. +``` +mpiexec -n +``` + +## Demo + +The imagenet_demo.py file contains code to experiments with EBMs on conditional ImageNet 128x128. To generate a gif on sampling, you can run the command: + +``` +python imagenet_demo.py --exp=imagenet128_cond --resume_iter=2238000 --swish_act +``` + +The ebm_sandbox.py file contains several different tasks that can be used to evaluate EBMs, which are defined by different settings of task flag in the file. For example, to visualize cross class mappings in CIFAR-10, you can run: + +``` +python ebm_sandbox.py --task=crossclass --num_steps=40 --exp=cifar10_cond --resume_iter=74700 +``` + + +## Generalization + +To test generalization to out of distribution classification for SVHN (with similar commands for other datasets) +``` +python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200 --large_model --svhnmix --cclass=False +``` + +To test classification on CIFAR-10 using a conditional model under either L2 or Li perturbations +``` +python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd= --num_steps=10 --lival=
  • --wider_model +``` + + +## Concept Combination + +To train EBMs on conditional dSprites dataset, you can train each model seperately on each conditioned latent in cond_pos, cond_rot, cond_shape, cond_scale, with an example command given below. + +``` +python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act --cond_pos --replay_batch -cclass +``` + +Once models are trained, they can be sampled from jointly by running + +``` +python ebm_combine.py --task=conceptcombine --exp_size= --exp_shape= --exp_pos= --exp_rot= --resume_size= --resume_shape= --resume_rot= --resume_pos= +``` + + + diff --git a/ais.py b/ais.py new file mode 100644 index 0000000..c22cc6f --- /dev/null +++ b/ais.py @@ -0,0 +1,249 @@ +import tensorflow as tf +import math +from hmc import hmc +from tensorflow.python.platform import flags +from torch.utils.data import DataLoader +from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet +from data import Cifar10, Mnist, DSprites +from scipy.misc import logsumexp +from scipy.misc import imsave +from utils import optimistic_restore +import os.path as osp +import numpy as np +from tqdm import tqdm + +flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single') +flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss') +flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored') +flags.DEFINE_string('exp', 'default', 'name of experiments') +flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel') +flags.DEFINE_integer('batch_size', 16, 'Size of inputs') +flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from') + +flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions') +flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.') +flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais') +flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian') +flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box') +flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model') +flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not') +flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights') +flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution') +flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network') +flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not') +flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not') +flags.DEFINE_bool('large_model', False, 'Use large model to evaluate') +flags.DEFINE_bool('wider_model', False, 'Use large model to evaluate') +flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps') + +FLAGS = flags.FLAGS + +label_default = np.eye(10)[0:1, :] +label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32)) + + +def unscale_im(im): + return (255 * np.clip(im, 0, 1)).astype(np.uint8) + +def gauss_prob_log(x, prec=1.0): + + nh = float(np.prod([s.value for s in x.get_shape()[1:]])) + norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec)) + prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec + + return norm_constant_log + prob_density_log + + +def uniform_prob_log(x): + + return tf.zeros(1) + + +def model_prob_log(x, e_func, weights, temp): + if FLAGS.cclass: + batch_size = tf.shape(x)[0] + label_tiled = tf.tile(label_default, (batch_size, 1)) + e_raw = e_func.forward(x, weights, label=label_tiled) + else: + e_raw = e_func.forward(x, weights) + energy = tf.reduce_sum(e_raw, axis=[1]) + return -temp * energy + + +def bridge_prob_neg_log(alpha, x, e_func, weights, temp): + + if FLAGS.dataset == "gauss": + norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature) + else: + norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp) + # Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely + + + if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss': + oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1]) + elif FLAGS.dataset == 'mnist': + oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1, 2]) + else: + oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis = [1, 2, 3]) + + return -norm_prob + oob_prob + + +def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10): + if FLAGS.dataset == "2d": + x = tf.placeholder(tf.float32, shape=(None, 2)) + elif FLAGS.dataset == "gauss": + x = tf.placeholder(tf.float32, shape=(None, FLAGS.gauss_dim)) + elif FLAGS.dataset == "mnist": + x = tf.placeholder(tf.float32, shape=(None, 28, 28)) + else: + x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3)) + + x_init = x + + alpha_prev = tf.placeholder(tf.float32, shape=()) + alpha_new = tf.placeholder(tf.float32, shape=()) + approx_lr = tf.placeholder(tf.float32, shape=()) + + chain_weights = tf.zeros(batch_size) + # for i in range(1, prop_dist+1): + # print("processing loop {}".format(i)) + # alpha_prev = (i-1) / prop_dist + # alpha_new = i / prop_dist + + prob_log_old_neg = bridge_prob_neg_log(alpha_prev, x, e_func, weights, temp) + prob_log_new_neg = bridge_prob_neg_log(alpha_new, x, e_func, weights, temp) + + chain_weights = -prob_log_new_neg + prob_log_old_neg + # chain_weights = tf.Print(chain_weights, [chain_weights]) + + # Sample new x using HMC + def unorm_prob(x): + return bridge_prob_neg_log(alpha_new, x, e_func, weights, temp) + + for j in range(1): + x = hmc(x, approx_lr, hmc_step, unorm_prob) + + return chain_weights, alpha_prev, alpha_new, x, x_init, approx_lr + + +def main(): + + # Initialize dataset + if FLAGS.dataset == 'cifar10': + dataset = Cifar10(train=False, rescale=FLAGS.rescale) + channel_num = 3 + dim_input = 32 * 32 * 3 + elif FLAGS.dataset == 'imagenet': + dataset = ImagenetClass() + channel_num = 3 + dim_input = 64 * 64 * 3 + elif FLAGS.dataset == 'mnist': + dataset = Mnist(train=False, rescale=FLAGS.rescale) + channel_num = 1 + dim_input = 28 * 28 * 1 + elif FLAGS.dataset == 'dsprites': + dataset = DSprites() + channel_num = 1 + dim_input = 64 * 64 * 1 + elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss': + dataset = Box2D() + + dim_output = 1 + data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True) + + if FLAGS.dataset == 'mnist': + model = MnistNet(num_channels=channel_num) + elif FLAGS.dataset == 'cifar10': + if FLAGS.large_model: + model = ResNet32Large(num_filters=128) + elif FLAGS.wider_model: + model = ResNet32Wider(num_filters=192) + else: + model = ResNet32(num_channels=channel_num, num_filters=128) + elif FLAGS.dataset == 'dsprites': + model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters) + + weights = model.construct_weights('context_{}'.format(0)) + + config = tf.ConfigProto() + sess = tf.Session(config=config) + saver = loader = tf.train.Saver(max_to_keep=10) + + sess.run(tf.global_variables_initializer()) + logdir = osp.join(FLAGS.logdir, FLAGS.exp) + + model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter)) + resume_itr = FLAGS.resume_iter + + if FLAGS.resume_iter != "-1": + optimistic_restore(sess, model_file) + else: + print("WARNING, YOU ARE NOT LOADING A SAVE FILE") + # saver.restore(sess, model_file) + + chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature) + print("Finished constructing ancestral sample ...................") + + if FLAGS.dataset != "gauss": + comb_weights_cum = [] + batch_size = tf.shape(x_init)[0] + label_tiled = tf.tile(label_default, (batch_size, 1)) + e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled) + e_pos_list = [] + + for data_corrupt, data, label_gt in tqdm(data_loader): + e_pos = sess.run([e_compute], {x_init: data})[0] + e_pos_list.extend(list(e_pos)) + + print(len(e_pos_list)) + print("Positive sample probability ", np.mean(e_pos_list), np.std(e_pos_list)) + + if FLAGS.dataset == "2d": + alr = 0.0045 + elif FLAGS.dataset == "gauss": + alr = 0.0085 + elif FLAGS.dataset == "mnist": + alr = 0.0065 + #90 alr = 0.0035 + else: + # alr = 0.0125 + if FLAGS.rescale == 8: + alr = 0.0085 + else: + alr = 0.0045 +# + for i in range(1): + tot_weight = 0 + for j in tqdm(range(1, FLAGS.pdist+1)): + if j == 1: + if FLAGS.dataset == "cifar10": + x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3)) + elif FLAGS.dataset == "gauss": + x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim)) + elif FLAGS.dataset == "mnist": + x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28)) + else: + x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2)) + + alpha_prev = (j-1) / FLAGS.pdist + alpha_new = j / FLAGS.pdist + cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))}) + tot_weight = tot_weight + cweight + + print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight)) + + tot_weight = 0 + + for j in tqdm(range(FLAGS.pdist, 0, -1)): + alpha_new = (j-1) / FLAGS.pdist + alpha_prev = j / FLAGS.pdist + cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))}) + tot_weight = tot_weight - cweight + + print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight)) + + + +if __name__ == "__main__": + main() diff --git a/custom_adam.py b/custom_adam.py new file mode 100644 index 0000000..71789fe --- /dev/null +++ b/custom_adam.py @@ -0,0 +1,236 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Adam for TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +import tensorflow as tf + + +@tf_export("train.AdamOptimizer") +class AdamOptimizer(optimizer.Optimizer): + """Optimizer that implements the Adam algorithm. + + See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) + ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). + """ + + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, + use_locking=False, name="Adam"): + """Construct a new Adam optimizer. + + Initialization: + + $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$ + $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$ + $$t := 0 \text{(Initialize timestep)}$$ + + The update rule for `variable` with gradient `g` uses an optimization + described at the end of section2 of the paper: + + $$t := t + 1$$ + $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ + + $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ + $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ + $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ + + The default value of 1e-8 for epsilon might not be a good default in + general. For example, when training an Inception network on ImageNet a + current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the + formulation just before Section 2.1 of the Kingma and Ba paper rather than + the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon + hat" in the paper. + + The sparse implementation of this algorithm (used when the gradient is an + IndexedSlices object, typically because of `tf.gather` or an embedding + lookup in the forward pass) does apply momentum to variable slices even if + they were not used in the forward pass (meaning they have a gradient equal + to zero). Momentum decay (beta1) is also applied to the entire momentum + accumulator. This means that the sparse behavior is equivalent to the dense + behavior (in contrast to some momentum implementations which ignore momentum + unless a variable slice was actually used). + + Args: + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". + + @compatibility(eager) + When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and + `epsilon` can each be a callable that takes no arguments and returns the + actual value to use. This can be useful for changing these values across + different invocations of optimizer functions. + @end_compatibility + """ + super(AdamOptimizer, self).__init__(use_locking, name) + self._lr = learning_rate + self._beta1 = beta1 + self._beta2 = beta2 + self._epsilon = epsilon + + # Tensor versions of the constructor arguments, created in _prepare(). + self._lr_t = None + self._beta1_t = None + self._beta2_t = None + self._epsilon_t = None + + # Created in SparseApply if needed. + self._updated_lr = None + + def _get_beta_accumulators(self): + with ops.init_scope(): + if context.executing_eagerly(): + graph = None + else: + graph = ops.get_default_graph() + return (self._get_non_slot_variable("beta1_power", graph=graph), + self._get_non_slot_variable("beta2_power", graph=graph)) + + def _create_slots(self, var_list): + # Create the beta1 and beta2 accumulators on the same device as the first + # variable. Sort the var_list to make sure this device is consistent across + # workers (these need to go on the same PS, otherwise some updates are + # silently ignored). + first_var = min(var_list, key=lambda x: x.name) + self._create_non_slot_variable(initial_value=self._beta1, + name="beta1_power", + colocate_with=first_var) + self._create_non_slot_variable(initial_value=self._beta2, + name="beta2_power", + colocate_with=first_var) + + # Create slots for the first and second moments. + for v in var_list: + self._zeros_slot(v, "m", self._name) + self._zeros_slot(v, "v", self._name) + + def _prepare(self): + lr = self._call_if_callable(self._lr) + beta1 = self._call_if_callable(self._beta1) + beta2 = self._call_if_callable(self._beta2) + epsilon = self._call_if_callable(self._epsilon) + + self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") + self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") + self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") + self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") + + def _apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + + clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1 + grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds) + # Clip gradients by 3 std + return training_ops.apply_adam( + var, m, v, + math_ops.cast(beta1_power, var.dtype.base_dtype), + math_ops.cast(beta2_power, var.dtype.base_dtype), + math_ops.cast(self._lr_t, var.dtype.base_dtype), + math_ops.cast(self._beta1_t, var.dtype.base_dtype), + math_ops.cast(self._beta2_t, var.dtype.base_dtype), + math_ops.cast(self._epsilon_t, var.dtype.base_dtype), + grad, use_locking=self._use_locking).op + + def _resource_apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + return training_ops.resource_apply_adam( + var.handle, m.handle, v.handle, + math_ops.cast(beta1_power, grad.dtype.base_dtype), + math_ops.cast(beta2_power, grad.dtype.base_dtype), + math_ops.cast(self._lr_t, grad.dtype.base_dtype), + math_ops.cast(self._beta1_t, grad.dtype.base_dtype), + math_ops.cast(self._beta2_t, grad.dtype.base_dtype), + math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), + grad, use_locking=self._use_locking) + + def _apply_sparse_shared(self, grad, var, indices, scatter_add): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + # m_t = beta1 * m + (1 - beta1) * g_t + m = self.get_slot(var, "m") + m_scaled_g_values = grad * (1 - beta1_t) + m_t = state_ops.assign(m, m * beta1_t, + use_locking=self._use_locking) + with ops.control_dependencies([m_t]): + m_t = scatter_add(m, indices, m_scaled_g_values) + # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) + v = self.get_slot(var, "v") + v_scaled_g_values = (grad * grad) * (1 - beta2_t) + v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) + with ops.control_dependencies([v_t]): + v_t = scatter_add(v, indices, v_scaled_g_values) + v_sqrt = math_ops.sqrt(v_t) + var_update = state_ops.assign_sub(var, + lr * m_t / (v_sqrt + epsilon_t), + use_locking=self._use_locking) + return control_flow_ops.group(*[var_update, m_t, v_t]) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared( + grad.values, var, grad.indices, + lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda + x, i, v, use_locking=self._use_locking)) + + def _resource_scatter_add(self, x, i, v): + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add( + x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + return self._apply_sparse_shared( + grad, var, indices, self._resource_scatter_add) + + def _finish(self, update_ops, name_scope): + # Update the power accumulators. + with ops.control_dependencies(update_ops): + beta1_power, beta2_power = self._get_beta_accumulators() + with ops.colocate_with(beta1_power): + update_beta1 = beta1_power.assign( + beta1_power * self._beta1_t, use_locking=self._use_locking) + update_beta2 = beta2_power.assign( + beta2_power * self._beta2_t, use_locking=self._use_locking) + return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], + name=name_scope) diff --git a/data.py b/data.py new file mode 100644 index 0000000..42d93b3 --- /dev/null +++ b/data.py @@ -0,0 +1,546 @@ +from tensorflow.python.platform import flags +from tensorflow.contrib.data.python.ops import batching, threadpool +import tensorflow as tf +import json +from torch.utils.data import Dataset +import pickle +import os.path as osp +import os +import numpy as np +import time +from scipy.misc import imread, imresize +from skimage.color import rgb2grey +from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder +from torchvision import transforms +from imagenet_preprocessing import ImagenetPreprocessor +import torch +import torchvision + +FLAGS = flags.FLAGS +ROOT_DIR = "./results" + +# Dataset Options +flags.DEFINE_string('dsprites_path', + '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', + 'path to dsprites characters') +flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'whether cutoff should always in image') +flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes') +flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes') +flags.DEFINE_bool('dsize_only', False,'fix all factors except for size of objects') +flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects') +flags.DEFINE_bool('dsprites_restrict', False, 'fix all factors except for rotation of objects') +flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images') + + +# Data augmentation options +flags.DEFINE_bool('cutout_inside', False,'whether cutoff should always in image') +flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout') +flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout') +flags.DEFINE_bool('cutout', False,'whether to add cutout regularizer to data') + + +def cutout(mask_color=(0, 0, 0)): + mask_size_half = FLAGS.cutout_mask_size // 2 + offset = 1 if FLAGS.cutout_mask_size % 2 == 0 else 0 + + def _cutout(image): + image = np.asarray(image).copy() + + if np.random.random() > FLAGS.cutout_prob: + return image + + h, w = image.shape[:2] + + if FLAGS.cutout_inside: + cxmin, cxmax = mask_size_half, w + offset - mask_size_half + cymin, cymax = mask_size_half, h + offset - mask_size_half + else: + cxmin, cxmax = 0, w + offset + cymin, cymax = 0, h + offset + + cx = np.random.randint(cxmin, cxmax) + cy = np.random.randint(cymin, cymax) + xmin = cx - mask_size_half + ymin = cy - mask_size_half + xmax = xmin + FLAGS.cutout_mask_size + ymax = ymin + FLAGS.cutout_mask_size + xmin = max(0, xmin) + ymin = max(0, ymin) + xmax = min(w, xmax) + ymax = min(h, ymax) + image[:, ymin:ymax, xmin:xmax] = np.array(mask_color)[:, None, None] + return image + + return _cutout + + +class TFImagenetLoader(Dataset): + + def __init__(self, split, batchsize, idx, num_workers, rescale=1): + IMAGENET_NUM_TRAIN_IMAGES = 1281167 + IMAGENET_NUM_VAL_IMAGES = 50000 + + self.rescale = rescale + + if split == "train": + im_length = IMAGENET_NUM_TRAIN_IMAGES + records_to_skip = im_length * idx // num_workers + records_to_read = im_length * (idx + 1) // num_workers - records_to_skip + else: + im_length = IMAGENET_NUM_VAL_IMAGES + + self.curr_sample = 0 + + index_path = osp.join(FLAGS.imagenet_datadir, 'index.json') + with open(index_path) as f: + metadata = json.load(f) + counts = metadata['record_counts'] + + if split == 'train': + file_names = list(sorted([x for x in counts.keys() if x.startswith('train')])) + + result_records_to_skip = None + files = [] + for filename in file_names: + records_in_file = counts[filename] + if records_to_skip >= records_in_file: + records_to_skip -= records_in_file + continue + elif records_to_read > 0: + if result_records_to_skip is None: + # Record the number to skip in the first file + result_records_to_skip = records_to_skip + files.append(filename) + records_to_read -= (records_in_file - records_to_skip) + records_to_skip = 0 + else: + break + else: + files = list(sorted([x for x in counts.keys() if x.startswith('validation')])) + + files = [osp.join(FLAGS.imagenet_datadir, x) for x in files] + preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess + + ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string) + ds = ds.apply(tf.data.TFRecordDataset) + ds = ds.take(im_length) + ds = ds.prefetch(buffer_size=FLAGS.batch_size) + ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000)) + ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4)) + ds = ds.prefetch(buffer_size=2) + + ds_iterator = ds.make_initializable_iterator() + labels, images = ds_iterator.get_next() + self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0) + self.labels = labels + + config = tf.ConfigProto(device_count = {'GPU': 0}) + sess = tf.Session(config=config) + sess.run(ds_iterator.initializer) + + self.im_length = im_length // batchsize + + self.sess = sess + + def __next__(self): + self.curr_sample += 1 + + sess = self.sess + + im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3)) + label, im = sess.run([self.labels, self.images]) + im = im * self.rescale + label = np.eye(1000)[label.squeeze() - 1] + im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label) + return im_corrupt, im, label + + def __iter__(self): + return self + + def __len__(self): + return self.im_length + +class CelebA(Dataset): + + def __init__(self): + self.path = "/root/data/img_align_celeba" + self.ims = os.listdir(self.path) + self.ims = [osp.join(self.path, im) for im in self.ims] + + def __len__(self): + return len(self.ims) + + def __getitem__(self, index): + label = 1 + + if FLAGS.single: + index = 0 + + path = self.ims[index] + im = imread(path) + im = imresize(im, (32, 32)) + image_size = 32 + im = im / 255. + + if FLAGS.datasource == 'default': + im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) + elif FLAGS.datasource == 'random': + im_corrupt = np.random.uniform( + 0, 1, size=(image_size, image_size, 3)) + + return im_corrupt, im, label + + +class Cifar10(Dataset): + def __init__( + self, + train=True, + full=False, + augment=False, + noise=True, + rescale=1.0): + + if augment: + transform_list = [ + torchvision.transforms.RandomCrop(32, padding=4), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ToTensor(), + ] + + if FLAGS.cutout: + transform_list.append(cutout()) + + transform = transforms.Compose(transform_list) + else: + transform = transforms.ToTensor() + + self.full = full + self.data = CIFAR10( + ROOT_DIR, + transform=transform, + train=train, + download=True) + self.test_data = CIFAR10( + ROOT_DIR, + transform=transform, + train=False, + download=True) + self.one_hot_map = np.eye(10) + self.noise = noise + self.rescale = rescale + + def __len__(self): + + if self.full: + return len(self.data) + len(self.test_data) + else: + return len(self.data) + + def __getitem__(self, index): + if not FLAGS.single: + if self.full: + if index >= len(self.data): + im, label = self.test_data[index - len(self.data)] + else: + im, label = self.data[index] + else: + im, label = self.data[index] + else: + im, label = self.data[0] + + im = np.transpose(im, (1, 2, 0)).numpy() + image_size = 32 + label = self.one_hot_map[label] + + im = im * 255 / 256 + + if self.noise: + im = im * self.rescale + \ + np.random.uniform(0, self.rescale * 1 / 256., im.shape) + + np.random.seed((index + int(time.time() * 1e7)) % 2**32) + + if FLAGS.datasource == 'default': + im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) + elif FLAGS.datasource == 'random': + im_corrupt = np.random.uniform( + 0.0, self.rescale, (image_size, image_size, 3)) + + return im_corrupt, im, label + + +class Cifar100(Dataset): + def __init__(self, train=True, augment=False): + + if augment: + transform_list = [ + torchvision.transforms.RandomCrop(32, padding=4), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ToTensor(), + ] + + if FLAGS.cutout: + transform_list.append(cutout()) + + transform = transforms.Compose(transform_list) + else: + transform = transforms.ToTensor() + + self.data = CIFAR100( + "/root/cifar100", + transform=transform, + train=train, + download=True) + self.one_hot_map = np.eye(100) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + if not FLAGS.single: + im, label = self.data[index] + else: + im, label = self.data[0] + + im = np.transpose(im, (1, 2, 0)).numpy() + image_size = 32 + label = self.one_hot_map[label] + im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape) + np.random.seed((index + int(time.time() * 1e7)) % 2**32) + + if FLAGS.datasource == 'default': + im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) + elif FLAGS.datasource == 'random': + im_corrupt = np.random.uniform( + 0.0, 1.0, (image_size, image_size, 3)) + + return im_corrupt, im, label + + +class Svhn(Dataset): + def __init__(self, train=True, augment=False): + + transform = transforms.ToTensor() + + self.data = SVHN("/root/svhn", transform=transform, download=True) + self.one_hot_map = np.eye(10) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + if not FLAGS.single: + im, label = self.data[index] + else: + em, label = self.data[0] + + im = np.transpose(im, (1, 2, 0)).numpy() + image_size = 32 + label = self.one_hot_map[label] + im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape) + np.random.seed((index + int(time.time() * 1e7)) % 2**32) + + if FLAGS.datasource == 'default': + im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) + elif FLAGS.datasource == 'random': + im_corrupt = np.random.uniform( + 0.0, 1.0, (image_size, image_size, 3)) + + return im_corrupt, im, label + + +class Mnist(Dataset): + def __init__(self, train=True, rescale=1.0): + self.data = MNIST( + "/root/mnist", + transform=transforms.ToTensor(), + download=True, train=train) + self.labels = np.eye(10) + self.rescale = rescale + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + im, label = self.data[index] + label = self.labels[label] + im = im.squeeze() + # im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28)) + # im = im.numpy() / 2 + 0.2 + im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28)) + im = im * self.rescale + image_size = 28 + + if FLAGS.datasource == 'default': + im_corrupt = im + 0.3 * np.random.randn(image_size, image_size) + elif FLAGS.datasource == 'random': + im_corrupt = np.random.uniform(0, self.rescale, (28, 28)) + + return im_corrupt, im, label + + +class DSprites(Dataset): + def __init__( + self, + cond_size=False, + cond_shape=False, + cond_pos=False, + cond_rot=False): + dat = np.load(FLAGS.dsprites_path) + + if FLAGS.dshape_only: + l = dat['latents_values'] + mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 / + 31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39) + self.data = np.tile(dat['imgs'][mask], (10000, 1, 1)) + self.label = np.tile(dat['latents_values'][mask], (10000, 1)) + self.label = self.label[:, 1:2] + elif FLAGS.dpos_only: + l = dat['latents_values'] + # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39) + mask = (l[:, 1] == 1) & ( + l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5) + self.data = np.tile(dat['imgs'][mask], (100, 1, 1)) + self.label = np.tile(dat['latents_values'][mask], (100, 1)) + self.label = self.label[:, 4:] + 0.5 + elif FLAGS.dsize_only: + l = dat['latents_values'] + # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39) + mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 / + 31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1) + self.data = np.tile(dat['imgs'][mask], (10000, 1, 1)) + self.label = np.tile(dat['latents_values'][mask], (10000, 1)) + self.label = (self.label[:, 2:3]) + elif FLAGS.drot_only: + l = dat['latents_values'] + mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 / + 31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1) + self.data = np.tile(dat['imgs'][mask], (100, 1, 1)) + self.label = np.tile(dat['latents_values'][mask], (100, 1)) + self.label = (self.label[:, 3:4]) + self.label = np.concatenate( + [np.cos(self.label), np.sin(self.label)], axis=1) + elif FLAGS.dsprites_restrict: + l = dat['latents_values'] + mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39) + + self.data = dat['imgs'][mask] + self.label = dat['latents_values'][mask] + else: + self.data = dat['imgs'] + self.label = dat['latents_values'] + + if cond_size: + self.label = self.label[:, 2:3] + elif cond_shape: + self.label = self.label[:, 1:2] + elif cond_pos: + self.label = self.label[:, 4:] + elif cond_rot: + self.label = self.label[:, 3:4] + self.label = np.concatenate( + [np.cos(self.label), np.sin(self.label)], axis=1) + else: + self.label = self.label[:, 1:2] + + self.identity = np.eye(3) + + def __len__(self): + return self.data.shape[0] + + def __getitem__(self, index): + im = self.data[index] + image_size = 64 + + if not ( + FLAGS.dpos_only or FLAGS.dsize_only) and ( + not FLAGS.cond_size) and ( + not FLAGS.cond_pos) and ( + not FLAGS.cond_rot) and ( + not FLAGS.drot_only): + label = self.identity[self.label[index].astype( + np.int32) - 1].squeeze() + else: + label = self.label[index] + + if FLAGS.datasource == 'default': + im_corrupt = im + 0.3 * np.random.randn(image_size, image_size) + elif FLAGS.datasource == 'random': + im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size) + + return im_corrupt, im, label + + +class Imagenet(Dataset): + def __init__(self, train=True, augment=False): + + if train: + for i in range(1, 11): + f = pickle.load( + open( + osp.join( + FLAGS.imagenet_path, + 'train_data_batch_{}'.format(i)), + 'rb')) + if i == 1: + labels = f['labels'] + data = f['data'] + else: + labels.extend(f['labels']) + data = np.vstack((data, f['data'])) + else: + f = pickle.load( + open( + osp.join( + FLAGS.imagenet_path, + 'val_data'), + 'rb')) + labels = f['labels'] + data = f['data'] + + self.labels = labels + self.data = data + self.one_hot_map = np.eye(1000) + + def __len__(self): + return self.data.shape[0] + + def __getitem__(self, index): + if not FLAGS.single: + im, label = self.data[index], self.labels[index] + else: + im, label = self.data[0], self.labels[0] + + label -= 1 + + im = im.reshape((3, 32, 32)) / 255 + im = im.transpose((1, 2, 0)) + image_size = 32 + label = self.one_hot_map[label] + im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape) + np.random.seed((index + int(time.time() * 1e7)) % 2**32) + + if FLAGS.datasource == 'default': + im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) + elif FLAGS.datasource == 'random': + im_corrupt = np.random.uniform( + 0.0, 1.0, (image_size, image_size, 3)) + + return im_corrupt, im, label + + +class Textures(Dataset): + def __init__(self, train=True, augment=False): + self.dataset = ImageFolder("/mnt/nfs/yilundu/data/dtd/images") + + def __len__(self): + return 2 * len(self.dataset) + + def __getitem__(self, index): + idx = index % (len(self.dataset)) + im, label = self.dataset[idx] + + im = np.array(im)[:32, :32] / 255 + im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape) + + return im, im, label diff --git a/ebm_combine.py b/ebm_combine.py new file mode 100644 index 0000000..c79a606 --- /dev/null +++ b/ebm_combine.py @@ -0,0 +1,698 @@ +import tensorflow as tf +import math +from tqdm import tqdm +from hmc import hmc +from tensorflow.python.platform import flags +from torch.utils.data import DataLoader, Dataset +from models import DspritesNet +from utils import optimistic_restore, ReplayBuffer +import os.path as osp +import numpy as np +from rl_algs.logger import TensorBoardOutputFormat +from scipy.misc import imsave +import os +from custom_adam import AdamOptimizer + +flags.DEFINE_integer('batch_size', 256, 'Size of inputs') +flags.DEFINE_integer('data_workers', 4, 'Number of workers to do things') +flags.DEFINE_string('logdir', 'cachedir', 'directory for logging') +flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored') +flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.') +flags.DEFINE_float('step_lr', 500, 'size of gradient descent size') +flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters') +flags.DEFINE_bool('cclass', True, 'not cclass') +flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons') +flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights') +flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution') +flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network') +flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results') +flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label') +flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.') +flags.DEFINE_bool('joint_shape', False, 'whether to use pos_size or pos_shape') +flags.DEFINE_bool('joint_rot', False, 'whether to use pos_size or pos_shape') + +# Conditions on which models to use +flags.DEFINE_bool('cond_pos', True, 'whether to condition on position') +flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation') +flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape') +flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale') + +flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of experiments') +flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of experiments') +flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of experiments') +flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of experiments') +flags.DEFINE_integer('resume_size', 169000, 'First iteration to resume') +flags.DEFINE_integer('resume_shape', 477000, 'Second iteration to resume') +flags.DEFINE_integer('resume_pos', 8000, 'Second iteration to resume') +flags.DEFINE_integer('resume_rot', 690000, 'Second iteration to resume') +flags.DEFINE_integer('break_steps', 300, 'steps to break') + +# Whether to train for gentest +flags.DEFINE_bool('train', False, 'whether to train on generalization into multiple different predictions') + +FLAGS = flags.FLAGS + +class DSpritesGen(Dataset): + def __init__(self, data, latents, frac=0.0): + + l = latents + + if FLAGS.joint_shape: + mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5) + elif FLAGS.joint_rot: + mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5) + else: + mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1) + + mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5) + + data_pos = data[mask_pos] + l_pos = l[mask_pos] + + data_size = data[mask_size] + l_size = l[mask_size] + + n = data_pos.shape[0] // data_size.shape[0] + + data_pos = np.tile(data_pos, (n, 1, 1)) + l_pos = np.tile(l_pos, (n, 1)) + + self.data = np.concatenate((data_pos, data_size), axis=0) + self.label = np.concatenate((l_pos, l_size), axis=0) + + mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39)) + data_add = data[mask_neg] + l_add = l[mask_neg] + + perm_idx = np.random.permutation(data_add.shape[0]) + select_idx = perm_idx[:int(frac*perm_idx.shape[0])] + data_add = data_add[select_idx] + l_add = l_add[select_idx] + + self.data = np.concatenate((self.data, data_add), axis=0) + self.label = np.concatenate((self.label, l_add), axis=0) + + self.identity = np.eye(3) + + def __len__(self): + return self.data.shape[0] + + def __getitem__(self, index): + im = self.data[index] + im_corrupt = 0.5 + 0.5 * np.random.randn(64, 64) + + if FLAGS.joint_shape: + label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1] + elif FLAGS.joint_rot: + label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])]) + else: + label_size = self.label[index, 2:3] + + label_pos = self.label[index, 4:] + + return (im_corrupt, im, label_size, label_pos) + + +def labeldiscover(sess, kvs, data, latents, save_exp_dir): + LABEL_SIZE = kvs['LABEL_SIZE'] + model_size = kvs['model_size'] + weight_size = kvs['weight_size'] + x_mod = kvs['X_NOISE'] + + label_output = LABEL_SIZE + for i in range(FLAGS.num_steps): + label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03) + e_noise = model_size.forward(x_mod, weight_size, label=label_output) + label_grad = tf.gradients(e_noise, [label_output])[0] + # label_grad = tf.Print(label_grad, [label_grad]) + label_output = label_output - 1.0 * label_grad + label_output = tf.clip_by_value(label_output, 0.5, 1.0) + + diffs = [] + for i in range(30): + s = i*FLAGS.batch_size + d = (i+1)*FLAGS.batch_size + data_i = data[s:d] + latent_i = latents[s:d] + latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1)) + + feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init} + size_pred = sess.run([label_output], feed_dict)[0] + size_gt = latent_i[:, 2:3] + + diffs.append(np.abs(size_pred - size_gt).mean()) + + print(np.array(diffs).mean()) + + +def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0): + # tf.reset_default_graph() + + if FLAGS.joint_shape: + model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=5) + LABEL = tf.placeholder(shape=(None, 5), dtype=tf.float32) + else: + model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3) + LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32) + + weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac)) + + X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32) + X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) + + X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL) + loss_sq = tf.reduce_mean(tf.square(X_out - X_label)) + + optimizer = AdamOptimizer(1e-3) + gvs = optimizer.compute_gradients(loss_sq) + gvs = [(k, v) for (k, v) in gvs if k is not None] + train_op = optimizer.apply_gradients(gvs) + + dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True) + + datafull = data + + itr = 0 + saver = tf.train.Saver() + + vs = optimizer.variables() + sess.run(tf.global_variables_initializer()) + + if FLAGS.train: + for _ in range(5): + for data_corrupt, data, label_size, label_pos in tqdm(dataloader): + + data_corrupt = data_corrupt.numpy() + label_size, label_pos = label_size.numpy(), label_pos.numpy() + + data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters) + label_comb = np.concatenate([label_size, label_pos], axis=1) + + feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb} + + output = [loss_sq, train_op] + + loss, _ = sess.run(output, feed_dict=feed_dict) + + itr += 1 + + saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline')) + + saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline')) + + l = latents + + if FLAGS.joint_shape: + mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5) + else: + mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31)))) + + data_gen = datafull[mask_gen] + latents_gen = latents[mask_gen] + losses = [] + + for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)): + data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters) + + if FLAGS.joint_shape: + latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1] + latent_pos = latent[:, 4:6] + latent = np.concatenate([latent_size, latent_pos], axis=1) + feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat} + else: + feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat} + loss = sess.run([loss_sq], feed_dict=feed_dict)[0] + # print(loss) + losses.append(loss) + + print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac)) + + + data_try = data_gen[:10] + data_init = np.random.randn(10, 2*FLAGS.num_filters) + + if FLAGS.joint_shape: + latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1] + latent_pos = latents_gen[:10, 4:] + else: + latent_scale = latents_gen[:10, 2:3] + latent_pos = latents_gen[:10, 4:] + + latent_tot = np.concatenate([latent_scale, latent_pos], axis=1) + + feed_dict = {X_feed: data_init, LABEL: latent_tot} + x_output = sess.run([X_out], feed_dict=feed_dict)[0] + x_output = np.clip(x_output, 0, 1) + + im_name = "size_scale_combine_genbaseline.png" + + x_output_wrap = np.ones((10, 66, 66)) + data_try_wrap = np.ones((10, 66, 66)) + + x_output_wrap[:, 1:-1, 1:-1] = x_output + data_try_wrap[:, 1:-1, 1:-1] = data_try + + im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2) + impath = osp.join(save_exp_dir, im_name) + imsave(impath, im_output) + print("Successfully saved images at {}".format(impath)) + + return np.mean(losses) + + +def gentest(sess, kvs, data, latents, save_exp_dir): + X_NOISE = kvs['X_NOISE'] + LABEL_SIZE = kvs['LABEL_SIZE'] + LABEL_SHAPE = kvs['LABEL_SHAPE'] + LABEL_POS = kvs['LABEL_POS'] + LABEL_ROT = kvs['LABEL_ROT'] + model_size = kvs['model_size'] + model_shape = kvs['model_shape'] + model_pos = kvs['model_pos'] + model_rot = kvs['model_rot'] + weight_size = kvs['weight_size'] + weight_shape = kvs['weight_shape'] + weight_pos = kvs['weight_pos'] + weight_rot = kvs['weight_rot'] + X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) + + datafull = data + # Test combination of generalization where we use slices of both training + x_final = X_NOISE + x_mod_size = X_NOISE + x_mod_pos = X_NOISE + + for i in range(FLAGS.num_steps): + + # use cond_pos + + energies = [] + x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005) + e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS) + + # energies.append(e_noise) + x_grad = tf.gradients(e_noise, [x_final])[0] + x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005) + x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad + x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1) + + if FLAGS.joint_shape: + # use cond_shape + e_noise = model_shape.forward(x_mod_pos, weight_shape, label=LABEL_SHAPE) + elif FLAGS.joint_rot: + e_noise = model_rot.forward(x_mod_pos, weight_rot, label=LABEL_ROT) + else: + # use cond_size + e_noise = model_size.forward(x_mod_pos, weight_size, label=LABEL_SIZE) + + # energies.append(e_noise) + # energy_stack = tf.concat(energies, axis=1) + # energy_stack = tf.reduce_logsumexp(-1*energy_stack, axis=1) + # energy_stack = tf.reduce_sum(energy_stack, axis=1) + + x_grad = tf.gradients(e_noise, [x_mod_pos])[0] + x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad + x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1) + + # for x_mod_size + # use cond_size + # e_noise = model_size.forward(x_mod_size, weight_size, label=LABEL_SIZE) + # x_grad = tf.gradients(e_noise, [x_mod_size])[0] + # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005) + # x_mod_size = x_mod_size - FLAGS.step_lr * x_grad + # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1) + + # # use cond_pos + # e_noise = model_pos.forward(x_mod_size, weight_pos, label=LABEL_POS) + # x_grad = tf.gradients(e_noise, [x_mod_size])[0] + # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005) + # x_mod_size = x_mod_size - FLAGS.step_lr * tf.stop_gradient(x_grad) + # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1) + + x_mod = x_mod_pos + x_final = x_mod + + + if FLAGS.joint_shape: + loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \ + model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True) + + energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \ + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS) + + energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \ + model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS) + elif FLAGS.joint_rot: + loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \ + model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True) + + energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \ + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS) + + energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \ + model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS) + else: + loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \ + model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True) + + energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \ + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS) + + energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \ + model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS) + + energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg)) + coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced)) + norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4 + neg_loss = coeff * (-1*energy_neg) / norm_constant + + loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg) + loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg))) + + optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999) + gvs = optimizer.compute_gradients(loss_total) + gvs = [(k, v) for (k, v) in gvs if k is not None] + train_op = optimizer.apply_gradients(gvs) + + vs = optimizer.variables() + sess.run(tf.variables_initializer(vs)) + + dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True) + + x_off = tf.reduce_mean(tf.square(x_mod - X)) + + itr = 0 + saver = tf.train.Saver() + x_mod = None + + + if FLAGS.train: + replay_buffer = ReplayBuffer(10000) + for _ in range(1): + + + for data_corrupt, data, label_size, label_pos in tqdm(dataloader): + data_corrupt = data_corrupt.numpy()[:, :, :] + data = data.numpy()[:, :, :] + + if x_mod is not None: + replay_buffer.add(x_mod) + replay_batch = replay_buffer.sample(FLAGS.batch_size) + replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95) + data_corrupt[replay_mask] = replay_batch[replay_mask] + + if FLAGS.joint_shape: + feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos} + elif FLAGS.joint_rot: + feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos} + else: + feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos} + + _, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict) + itr += 1 + + if itr % 10 == 0: + print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr)) + + if itr == FLAGS.break_steps: + break + + + saver.save(sess, osp.join(save_exp_dir, 'model_gentest')) + + saver.restore(sess, osp.join(save_exp_dir, 'model_gentest')) + + l = latents + + if FLAGS.joint_shape: + mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5) + elif FLAGS.joint_rot: + mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5) + else: + mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31)))) + + data_gen = datafull[mask_gen] + latents_gen = latents[mask_gen] + + losses = [] + + for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)): + x = 0.5 + np.random.randn(*dat.shape) + + if FLAGS.joint_shape: + feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat} + elif FLAGS.joint_rot: + feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat} + else: + feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat} + + for i in range(2): + x = sess.run([x_final], feed_dict=feed_dict)[0] + feed_dict[X_NOISE] = x + + loss = sess.run([x_off], feed_dict=feed_dict)[0] + losses.append(loss) + + print("Mean MSE loss of {} ".format(np.mean(losses))) + + data_try = data_gen[:10] + data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64) + latent_scale = latents_gen[:10, 2:3] + latent_pos = latents_gen[:10, 4:] + + if FLAGS.joint_shape: + feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos} + elif FLAGS.joint_rot: + feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1), LABEL_POS: latent[:10, 4:], X_NOISE: data_init} + else: + feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos} + + x_output = sess.run([x_final], feed_dict=feed_dict)[0] + + if FLAGS.joint_shape: + im_name = "size_shape_combine_gentest.png" + else: + im_name = "size_scale_combine_gentest.png" + + x_output_wrap = np.ones((10, 66, 66)) + data_try_wrap = np.ones((10, 66, 66)) + + x_output_wrap[:, 1:-1, 1:-1] = x_output + data_try_wrap[:, 1:-1, 1:-1] = data_try + + im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2) + impath = osp.join(save_exp_dir, im_name) + imsave(impath, im_output) + print("Successfully saved images at {}".format(impath)) + + + +def conceptcombine(sess, kvs, data, latents, save_exp_dir): + X_NOISE = kvs['X_NOISE'] + LABEL_SIZE = kvs['LABEL_SIZE'] + LABEL_SHAPE = kvs['LABEL_SHAPE'] + LABEL_POS = kvs['LABEL_POS'] + LABEL_ROT = kvs['LABEL_ROT'] + model_size = kvs['model_size'] + model_shape = kvs['model_shape'] + model_pos = kvs['model_pos'] + model_rot = kvs['model_rot'] + weight_size = kvs['weight_size'] + weight_shape = kvs['weight_shape'] + weight_pos = kvs['weight_pos'] + weight_rot = kvs['weight_rot'] + + x_mod = X_NOISE + for i in range(FLAGS.num_steps): + + if FLAGS.cond_scale: + e_noise = model_size.forward(x_mod, weight_size, label=LABEL_SIZE) + x_grad = tf.gradients(e_noise, [x_mod])[0] + x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005) + x_mod = x_mod - FLAGS.step_lr * x_grad + x_mod = tf.clip_by_value(x_mod, 0, 1) + + if FLAGS.cond_shape: + e_noise = model_shape.forward(x_mod, weight_shape, label=LABEL_SHAPE) + x_grad = tf.gradients(e_noise, [x_mod])[0] + x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005) + x_mod = x_mod - FLAGS.step_lr * x_grad + x_mod = tf.clip_by_value(x_mod, 0, 1) + + if FLAGS.cond_pos: + e_noise = model_pos.forward(x_mod, weight_pos, label=LABEL_POS) + x_grad = tf.gradients(e_noise, [x_mod])[0] + x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005) + x_mod = x_mod - FLAGS.step_lr * x_grad + x_mod = tf.clip_by_value(x_mod, 0, 1) + + if FLAGS.cond_rot: + e_noise = model_rot.forward(x_mod, weight_rot, label=LABEL_ROT) + x_grad = tf.gradients(e_noise, [x_mod])[0] + x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005) + x_mod = x_mod - FLAGS.step_lr * x_grad + x_mod = tf.clip_by_value(x_mod, 0, 1) + + print("Finished constructing loop {}".format(i)) + + x_final = x_mod + + data_try = data[:10] + data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64) + label_scale = latents[:10, 2:3] + label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)] + label_rot = latents[:10, 3:4] + label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1) + label_pos = latents[:10, 4:] + + feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos, + LABEL_ROT: label_rot} + x_out = sess.run([x_final], feed_dict)[0] + + im_name = "im" + + if FLAGS.cond_scale: + im_name += "_condscale" + + if FLAGS.cond_shape: + im_name += "_condshape" + + if FLAGS.cond_pos: + im_name += "_condpos" + + if FLAGS.cond_rot: + im_name += "_condrot" + + im_name += ".png" + + x_out_pad, data_try_pad = np.ones((10, 66, 66)), np.ones((10, 66, 66)) + x_out_pad[:, 1:-1, 1:-1] = x_out + data_try_pad[:, 1:-1, 1:-1] = data_try + + im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2) + impath = osp.join(save_exp_dir, im_name) + imsave(impath, im_output) + print("Successfully saved images at {}".format(impath)) + +def main(): + data = np.load(FLAGS.dsprites_path)['imgs'] + l = latents = np.load(FLAGS.dsprites_path)['latents_values'] + + np.random.seed(1) + idx = np.random.permutation(data.shape[0]) + + data = data[idx] + latents = latents[idx] + + config = tf.ConfigProto() + sess = tf.Session(config=config) + + # Model 1 will be conditioned on size + model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True) + weight_size = model_size.construct_weights('context_0') + + # Model 2 will be conditioned on shape + model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True) + weight_shape = model_shape.construct_weights('context_1') + + # Model 3 will be conditioned on position + model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True) + weight_pos = model_pos.construct_weights('context_2') + + # Model 4 will be conditioned on rotation + model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True) + weight_rot = model_rot.construct_weights('context_3') + + sess.run(tf.global_variables_initializer()) + save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size)) + + v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0)) + v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list} + + if FLAGS.cond_scale: + saver = tf.train.Saver(v_map) + saver.restore(sess, save_path_size) + + save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape)) + + v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1)) + v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list} + + if FLAGS.cond_shape: + saver = tf.train.Saver(v_map) + saver.restore(sess, save_path_shape) + + + save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos)) + v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2)) + v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list} + saver = tf.train.Saver(v_map) + + if FLAGS.cond_pos: + saver.restore(sess, save_path_pos) + + + save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot)) + v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3)) + v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list} + saver = tf.train.Saver(v_map) + + if FLAGS.cond_rot: + saver.restore(sess, save_path_rot) + + X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) + LABEL_SIZE = tf.placeholder(shape=(None, 1), dtype=tf.float32) + LABEL_SHAPE = tf.placeholder(shape=(None, 3), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) + LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32) + + x_mod = X_NOISE + + kvs = {} + kvs['X_NOISE'] = X_NOISE + kvs['LABEL_SIZE'] = LABEL_SIZE + kvs['LABEL_SHAPE'] = LABEL_SHAPE + kvs['LABEL_POS'] = LABEL_POS + kvs['LABEL_ROT'] = LABEL_ROT + kvs['model_size'] = model_size + kvs['model_shape'] = model_shape + kvs['model_pos'] = model_pos + kvs['model_rot'] = model_rot + kvs['weight_size'] = weight_size + kvs['weight_shape'] = weight_shape + kvs['weight_pos'] = weight_pos + kvs['weight_rot'] = weight_rot + + save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape)) + if not osp.exists(save_exp_dir): + os.makedirs(save_exp_dir) + + + if FLAGS.task == 'conceptcombine': + conceptcombine(sess, kvs, data, latents, save_exp_dir) + elif FLAGS.task == 'labeldiscover': + labeldiscover(sess, kvs, data, latents, save_exp_dir) + elif FLAGS.task == 'gentest': + save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos)) + if not osp.exists(save_exp_dir): + os.makedirs(save_exp_dir) + + gentest(sess, kvs, data, latents, save_exp_dir) + elif FLAGS.task == 'genbaseline': + save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos)) + if not osp.exists(save_exp_dir): + os.makedirs(save_exp_dir) + + if FLAGS.plot_curve: + mse_losses = [] + for frac in [i/10 for i in range(11)]: + mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac) + mse_losses.append(mse_loss) + np.save("mse_baseline_comb.npy", mse_losses) + else: + genbaseline(sess, kvs, data, latents, save_exp_dir) + + + +if __name__ == "__main__": + main() diff --git a/ebm_sandbox.py b/ebm_sandbox.py new file mode 100644 index 0000000..531d517 --- /dev/null +++ b/ebm_sandbox.py @@ -0,0 +1,981 @@ +import tensorflow as tf +import math +from tqdm import tqdm +from tensorflow.python.platform import flags +from torch.utils.data import DataLoader +import torch +from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, DspritesNet +from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, DSprites +from utils import optimistic_restore, set_seed +import os.path as osp +import numpy as np +from baselines.logger import TensorBoardOutputFormat +from scipy.misc import imsave +import os +import sklearn.metrics as sk +from baselines.common.tf_util import initialize +from scipy.linalg import eig +import matplotlib.pyplot as plt + +# set_seed(1) + +flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single') +flags.DEFINE_string('dataset', 'cifar10', 'omniglot or imagenet or omniglotfull or cifar10 or mnist or dsprites') +flags.DEFINE_string('logdir', 'sandbox_cachedir', 'location where log of experiments will be stored') +flags.DEFINE_string('task', 'label', 'using conditional energy based models for classification' + 'anticorrupt: restore salt and pepper noise),' + ' boxcorrupt: restore empty portion of image' + 'or crossclass: change images from one class to another' + 'or cycleclass: view image change across a label' + 'or nearestneighbor which returns the nearest images in the test set' + 'or latent to traverse the latent space of an EBM through eigenvectors of the hessian (dsprites only)' + 'or mixenergy to evaluate out of distribution generalization compared to other datasets') +flags.DEFINE_bool('hessian', True, 'Whether to use the hessian or the Jacobian for latent traversals') +flags.DEFINE_string('exp', 'default', 'name of experiments') +flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel') +flags.DEFINE_integer('batch_size', 32, 'Size of inputs') +flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from') +flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not') +flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights') +flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution') +flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network') +flags.DEFINE_bool('train', True, 'Whether to train or test network') +flags.DEFINE_bool('single', False, 'whether to use one sample to debug') +flags.DEFINE_bool('cclass', True, 'whether to use a conditional model (required for task label)') +flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label') +flags.DEFINE_float('step_lr', 10.0, 'step size for updates on label') +flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images') +flags.DEFINE_bool('large_model', False, 'Whether to use a large model') +flags.DEFINE_bool('larger_model', False, 'Whether to use a larger model') +flags.DEFINE_bool('wider_model', False, 'Whether to use a widermodel model') +flags.DEFINE_bool('svhn', False, 'Whether to test on SVHN') + +# Conditions for mixenergy (outlier detection) +flags.DEFINE_bool('svhnmix', False, 'Whether to test mix on SVHN') +flags.DEFINE_bool('cifar100mix', False, 'Whether to test mix on CIFAR100') +flags.DEFINE_bool('texturemix', False, 'Whether to test mix on Textures dataset') +flags.DEFINE_bool('randommix', False, 'Whether to test mix on random dataset') + +# Conditions for label task (adversarial classification) +flags.DEFINE_integer('lival', 8, 'Value of constraint for li') +flags.DEFINE_integer('l2val', 40, 'Value of constraint for l2') +flags.DEFINE_integer('pgd', 0, 'number of steps project gradient descent to run') +flags.DEFINE_integer('lnorm', -1, 'linfinity is -1, l2 norm is 2') +flags.DEFINE_bool('labelgrid', False, 'Make a grid of labels') + +# Conditions on which models to use +flags.DEFINE_bool('cond_pos', True, 'whether to condition on position') +flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation') +flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape') +flags.DEFINE_bool('cond_size', True, 'whether to condition on scale') + +FLAGS = flags.FLAGS + +def rescale_im(im): + im = np.clip(im, 0, 1) + return np.round(im * 255).astype(np.uint8) + +def label(dataloader, test_dataloader, target_vars, sess, l1val=8, l2val=40): + X = target_vars['X'] + Y = target_vars['Y'] + Y_GT = target_vars['Y_GT'] + accuracy = target_vars['accuracy'] + train_op = target_vars['train_op'] + l1_norm = target_vars['l1_norm'] + l2_norm = target_vars['l2_norm'] + + label_init = np.random.uniform(0, 1, (FLAGS.batch_size, 10)) + label_init = label_init / label_init.sum(axis=1, keepdims=True) + + label_init = np.tile(np.eye(10)[None :, :], (FLAGS.batch_size, 1, 1)) + label_init = np.reshape(label_init, (-1, 10)) + + for i in range(1): + emp_accuracies = [] + + for data_corrupt, data, label_gt in tqdm(test_dataloader): + feed_dict = {X: data, Y_GT: label_gt, Y: label_init, l1_norm: l1val, l2_norm: l2val} + emp_accuracy = sess.run([accuracy], feed_dict) + emp_accuracies.append(emp_accuracy) + print(np.array(emp_accuracies).mean()) + + print("Received total accuracy of {} for li of {} and l2 of {}".format(np.array(emp_accuracies).mean(), l1val, l2val)) + + return np.array(emp_accuracies).mean() + + +def labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=8, l2val=40): + X = target_vars['X'] + Y = target_vars['Y'] + Y_GT = target_vars['Y_GT'] + accuracy = target_vars['accuracy'] + train_op = target_vars['train_op'] + l1_norm = target_vars['l1_norm'] + l2_norm = target_vars['l2_norm'] + + label_init = np.random.uniform(0, 1, (FLAGS.batch_size, 10)) + label_init = label_init / label_init.sum(axis=1, keepdims=True) + + label_init = np.tile(np.eye(10)[None :, :], (FLAGS.batch_size, 1, 1)) + label_init = np.reshape(label_init, (-1, 10)) + + itr = 0 + + if FLAGS.train: + for i in range(1): + for data_corrupt, data, label_gt in tqdm(dataloader): + feed_dict = {X: data, Y_GT: label_gt, Y: label_init} + acc, _ = sess.run([accuracy, train_op], feed_dict) + + itr += 1 + + if itr % 10 == 0: + print(acc) + + saver.save(sess, osp.join(savedir, "model_supervised")) + + saver.restore(sess, osp.join(savedir, "model_supervised")) + + + for i in range(1): + emp_accuracies = [] + + for data_corrupt, data, label_gt in tqdm(test_dataloader): + feed_dict = {X: data, Y_GT: label_gt, Y: label_init, l1_norm: l1val, l2_norm: l2val} + emp_accuracy = sess.run([accuracy], feed_dict) + emp_accuracies.append(emp_accuracy) + print(np.array(emp_accuracies).mean()) + + + print("Received total accuracy of {} for li of {} and l2 of {}".format(np.array(emp_accuracies).mean(), l1val, l2val)) + + return np.array(emp_accuracies).mean() + + +def energyeval(dataloader, test_dataloader, target_vars, sess): + X = target_vars['X'] + Y_GT = target_vars['Y_GT'] + energy = target_vars['energy'] + energy_end = target_vars['energy_end'] + + test_energies = [] + train_energies = [] + for data_corrupt, data, label_gt in tqdm(test_dataloader): + feed_dict = {X: data, Y_GT: label_gt} + test_energy = sess.run([energy], feed_dict)[0] + test_energies.extend(list(test_energy)) + + for data_corrupt, data, label_gt in tqdm(dataloader): + feed_dict = {X: data, Y_GT: label_gt} + train_energy = sess.run([energy], feed_dict)[0] + train_energies.extend(list(train_energy)) + + print(len(train_energies)) + print(len(test_energies)) + + print("Train energies of {} with std {}".format(np.mean(train_energies), np.std(train_energies))) + print("Test energies of {} with std {}".format(np.mean(test_energies), np.std(test_energies))) + + np.save("train_ebm.npy", train_energies) + np.save("test_ebm.npy", test_energies) + + +def energyevalmix(dataloader, test_dataloader, target_vars, sess): + X = target_vars['X'] + Y_GT = target_vars['Y_GT'] + energy = target_vars['energy'] + + if FLAGS.svhnmix: + dataset = Svhn(train=False) + test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False) + test_iter = iter(test_dataloader_val) + elif FLAGS.cifar100mix: + dataset = Cifar100(train=False) + test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False) + test_iter = iter(test_dataloader_val) + elif FLAGS.texturemix: + dataset = Textures() + test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False) + test_iter = iter(test_dataloader_val) + + probs = [] + labels = [] + negs = [] + pos = [] + for data_corrupt, data, label_gt in tqdm(test_dataloader): + data = data.numpy() + data_corrupt = data_corrupt.numpy() + if FLAGS.svhnmix: + _, data_mix, _ = test_iter.next() + elif FLAGS.cifar100mix: + _, data_mix, _ = test_iter.next() + elif FLAGS.texturemix: + _, data_mix, _ = test_iter.next() + elif FLAGS.randommix: + data_mix = np.random.randn(FLAGS.batch_size, 32, 32, 3) * 0.5 + 0.5 + else: + data_idx = np.concatenate([np.arange(1, data.shape[0]), [0]]) + data_other = data[data_idx] + data_mix = (data + data_other) / 2 + + data_mix = data_mix[:data.shape[0]] + + if FLAGS.cclass: + # It's unfair to take a random class + label_gt= np.tile(np.eye(10), (data.shape[0], 1, 1)) + label_gt = label_gt.reshape(data.shape[0] * 10, 10) + data_mix = np.tile(data_mix[:, None, :, :, :], (1, 10, 1, 1, 1)) + data = np.tile(data[:, None, :, :, :], (1, 10, 1, 1, 1)) + + data_mix = data_mix.reshape(-1, 32, 32, 3) + data = data.reshape(-1, 32, 32, 3) + + + feed_dict = {X: data, Y_GT: label_gt} + feed_dict_neg = {X: data_mix, Y_GT: label_gt} + + pos_energy = sess.run([energy], feed_dict)[0] + neg_energy = sess.run([energy], feed_dict_neg)[0] + + if FLAGS.cclass: + pos_energy = pos_energy.reshape(-1, 10).min(axis=1) + neg_energy = neg_energy.reshape(-1, 10).min(axis=1) + + probs.extend(list(-1*pos_energy)) + probs.extend(list(-1*neg_energy)) + pos.extend(list(-1*pos_energy)) + negs.extend(list(-1*neg_energy)) + labels.extend([1]*pos_energy.shape[0]) + labels.extend([0]*neg_energy.shape[0]) + + pos, negs = np.array(pos), np.array(negs) + np.save("pos.npy", pos) + np.save("neg.npy", negs) + auroc = sk.roc_auc_score(labels, probs) + print("Roc score of {}".format(auroc)) + + +def anticorrupt(dataloader, weights, model, target_vars, logdir, sess): + X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final'] + for data_corrupt, data, label_gt in tqdm(dataloader): + data, label_gt = data.numpy(), label_gt.numpy() + + noise = np.random.uniform(0, 1, size=[data.shape[0], data.shape[1], data.shape[2]]) + low_mask = noise < 0.05 + high_mask = (noise > 0.05) & (noise < 0.1) + + print(high_mask.shape) + + data_corrupt = data.copy() + data_corrupt[low_mask] = 0.1 + data_corrupt[high_mask] = 0.9 + data_corrupt_init = data_corrupt + + for i in range(5): + feed_dict = {X: data_corrupt, Y_GT: label_gt} + data_corrupt = sess.run([X_final], feed_dict)[0] + + data_uncorrupt = data_corrupt + data_corrupt, data_uncorrupt, data = rescale_im(data_corrupt_init), rescale_im(data_uncorrupt), rescale_im(data) + + panel_im = np.zeros((32*20, 32*3, 3)).astype(np.uint8) + + for i in range(20): + panel_im[32*i:32*i+32, :32] = data_corrupt[i] + panel_im[32*i:32*i+32, 32:64] = data_uncorrupt[i] + panel_im[32*i:32*i+32, 64:] = data[i] + + imsave(osp.join(logdir, "anticorrupt.png"), panel_im) + assert False + + +def boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess): + X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final'] + eval_im = 10000 + + data_diff = [] + for data_corrupt, data, label_gt in tqdm(dataloader): + data, label_gt = data.numpy(), label_gt.numpy() + data_uncorrupts = [] + + data_corrupt = data.copy() + data_corrupt[:, 16:, :] = np.random.uniform(0, 1, (FLAGS.batch_size, 16, 32, 3)) + + data_corrupt_init = data_corrupt + + for j in range(10): + feed_dict = {X: data_corrupt, Y_GT: label_gt} + data_corrupt = sess.run([X_final], feed_dict)[0] + + val = np.mean(np.square(data_corrupt - data), axis=(1, 2, 3)) + data_diff.extend(list(val)) + + if len(data_diff) > eval_im: + break + + print("Mean {} and std {} for train dataloader".format(np.mean(data_diff), np.std(data_diff))) + + np.save("data_diff_train_image.npy", data_diff) + + data_diff = [] + + for data_corrupt, data, label_gt in tqdm(test_dataloader): + data, label_gt = data.numpy(), label_gt.numpy() + data_uncorrupts = [] + + data_corrupt = data.copy() + data_corrupt[:, 16:, :] = np.random.uniform(0, 1, (FLAGS.batch_size, 16, 32, 3)) + + data_corrupt_init = data_corrupt + + for j in range(10): + feed_dict = {X: data_corrupt, Y_GT: label_gt} + data_corrupt = sess.run([X_final], feed_dict)[0] + + data_diff.extend(list(np.mean(np.square(data_corrupt - data), axis=(1, 2, 3)))) + + if len(data_diff) > eval_im: + break + + print("Mean {} and std {} for test dataloader".format(np.mean(data_diff), np.std(data_diff))) + + np.save("data_diff_test_image.npy", data_diff) + + +def crossclass(dataloader, weights, model, target_vars, logdir, sess): + X, Y_GT, X_mods, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_mods'], target_vars['X_final'] + for data_corrupt, data, label_gt in tqdm(dataloader): + data, label_gt = data.numpy(), label_gt.numpy() + data_corrupt = data.copy() + data_corrupt[1:] = data_corrupt[0:-1] + data_corrupt[0] = data[-1] + + data_mods = [] + data_mod = data_corrupt + + for i in range(10): + data_mods.append(data_mod) + + feed_dict = {X: data_mod, Y_GT: label_gt} + data_mod = sess.run(X_final, feed_dict) + + + + data_corrupt, data = rescale_im(data_corrupt), rescale_im(data) + + data_mods = [rescale_im(data_mod) for data_mod in data_mods] + + panel_im = np.zeros((32*20, 32*(len(data_mods) + 2), 3)).astype(np.uint8) + + for i in range(20): + panel_im[32*i:32*i+32, :32] = data_corrupt[i] + + for j in range(len(data_mods)): + panel_im[32*i:32*i+32, 32*(j+1):32*(j+2)] = data_mods[j][i] + + panel_im[32*i:32*i+32, -32:] = data[i] + + imsave(osp.join(logdir, "crossclass.png"), panel_im) + assert False + + +def cycleclass(dataloader, weights, model, target_vars, logdir, sess): + # X, Y_GT, X_final, X_targ = target_vars['X'], target_vars['Y_GT'], target_vars['X_final'], target_vars['X_targ'] + X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final'] + for data_corrupt, data, label_gt in tqdm(dataloader): + data, label_gt = data.numpy(), label_gt.numpy() + data_corrupt = data_corrupt.numpy() + + + data_mods = [] + x_curr = data_corrupt + x_target = np.random.uniform(0, 1, data_corrupt.shape) + # x_target = np.tile(x_target, (1, 32, 32, 1)) + + + for i in range(20): + feed_dict = {X: x_curr, Y_GT: label_gt} + x_curr_new = sess.run(X_final, feed_dict) + x_curr = x_curr_new + data_mods.append(x_curr_new) + + if i > 30: + x_target = np.random.uniform(0, 1, data_corrupt.shape) + + data_corrupt, data = rescale_im(data_corrupt), rescale_im(data) + + data_mods = [rescale_im(data_mod) for data_mod in data_mods] + + panel_im = np.zeros((32*100, 32*(len(data_mods) + 2), 3)).astype(np.uint8) + + for i in range(100): + panel_im[32*i:32*i+32, :32] = data_corrupt[i] + + for j in range(len(data_mods)): + panel_im[32*i:32*i+32, 32*(j+1):32*(j+2)] = data_mods[j][i] + + panel_im[32*i:32*i+32, -32:] = data[i] + + imsave(osp.join(logdir, "cycleclass.png"), panel_im) + assert False + + +def democlass(dataloader, weights, model, target_vars, logdir, sess): + X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final'] + panel_im = np.zeros((5*32, 10*32, 3)).astype(np.uint8) + for i in range(10): + data_corrupt = np.random.uniform(0, 1, (5, 32, 32, 3)) + label_gt = np.tile(np.eye(10)[i:i+1], (5, 1)) + + feed_dict = {X: data_corrupt, Y_GT: label_gt} + x_final = sess.run([X_final], feed_dict)[0] + + x_final = rescale_im(x_final) + + row = i // 2 + col = i % 2 + + start_idx = col * 32 * 5 + row_idx = row * 32 + + for j in range(5): + panel_im[row_idx:row_idx+32, start_idx+j*32:start_idx+(j+1) * 32] = x_final[j] + + imsave(osp.join(logdir, "democlass.png"), panel_im) + + +def construct_finetune_label(weight, X, Y, Y_GT, model, target_vars): + l1_norm = tf.placeholder(shape=(), dtype=tf.float32) + l2_norm = tf.placeholder(shape=(), dtype=tf.float32) + + def compute_logit(X, stop_grad=False, num_steps=0): + batch_size = tf.shape(X)[0] + X = tf.reshape(X, (batch_size, 1, 32, 32, 3)) + X = tf.reshape(tf.tile(X, (1, 10, 1, 1, 1)), (batch_size * 10, 32, 32, 3)) + Y_new = tf.reshape(Y, (batch_size*10, 10)) + + X_min = X - 8 / 255. + X_max = X + 8 / 255. + + for i in range(num_steps): + X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005) + + energy_noise = model.forward(X, weights, label=Y, reuse=True) + x_grad = tf.gradients(energy_noise, [X])[0] + + + if FLAGS.proj_norm != 0.0: + x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm) + + X = X - FLAGS.step_lr * x_grad + X = tf.maximum(tf.minimum(X, X_max), X_min) + + energy = model.forward(X, weight, label=Y_new) + energy = -tf.reshape(energy, (batch_size, 10)) + + if stop_grad: + energy = tf.stop_gradient(energy) + + return energy + + for i in range(FLAGS.pgd): + if FLAGS.train: + break + + print("Constructed loop {} of pgd attack".format(i)) + X_init = X + if i == 0: + X = X + tf.to_float(tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)) / 255. + + logit = compute_logit(X) + loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logit) + + x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255. + X = X + 2 * x_grad + + if FLAGS.lnorm == -1: + X = tf.maximum(tf.minimum(X, X_max), X_min) + elif FLAGS.lnorm == 2: + X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255., axes=[1, 2, 3]) + + + energy = compute_logit(X, num_steps=0) + logits = energy + labels = tf.argmax(Y_GT, axis=1) + loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logits) + + + optimizer = tf.train.AdamOptimizer(1e-3) + train_op = optimizer.minimize(loss) + accuracy = tf.contrib.metrics.accuracy(tf.argmax(logits, axis=1), labels) + + target_vars['accuracy'] = accuracy + target_vars['train_op'] = train_op + target_vars['l1_norm'] = l1_norm + target_vars['l2_norm'] = l2_norm + + +def construct_latent(weights, X, Y_GT, model, target_vars): + + eps = 0.001 + X_init = X[0:1] + + def traversals(model, X, weights, Y_GT): + if FLAGS.hessian: + e_pos = model.forward(X, weights, label=Y_GT) + hessian = tf.hessians(e_pos, X) + hessian = tf.reshape(hessian, (1, 64*64, 64*64))[0] + e, v = tf.linalg.eigh(hessian) + else: + latent = model.forward(X, weights, label=Y_GT, return_logit=True) + latents = tf.split(latent, 128, axis=1) + jacobian = [tf.gradients(latent, X)[0] for latent in latents] + jacobian = tf.stack(jacobian, axis=1) + jacobian = tf.reshape(jacobian, (tf.shape(jacobian)[1], tf.shape(jacobian)[1], 64*64)) + s, _, v = tf.linalg.svd(jacobian) + + return v + + + var_scale = 1.0 + n = 3 + xs = [] + + v = traversals(model, X_init, weights, Y_GT) + + for i in range(n): + var = tf.reshape(v[:, i], (1, 64, 64)) + X_plus = X_init - var_scale * var + X_min = X_init + var_scale * var + + xs.extend([X_plus, X_min]) + + x_stack = tf.stack(xs, axis=0) + + e_pos_hess_modify = model.forward(x_stack, weights, label=Y_GT) + + for i in range(20): + x_stack = x_stack + tf.random_normal(tf.shape(x_stack), mean=0.0, stddev=0.005) + e_pos = model.forward(x_stack, weights, label=Y_GT) + + x_grad = tf.gradients(e_pos, [x_stack])[0] + x_stack = x_stack - 4*FLAGS.step_lr * x_grad + + x_stack = tf.clip_by_value(x_stack, 0, 1) + + x_mods = tf.split(X, 6) + + eigs = [] + for j in range(6): + x_mod = x_mods[j] + v = traversals(model, x_mod, weights, Y_GT) + + idx = j // 2 + var = tf.reshape(v[:, idx], (1, 64, 64)) + + if j % 2 == 1: + x_mod = x_mod + var_scale * var + eigs.append(var) + else: + x_mod = x_mod - var_scale * var + eigs.append(-var) + + x_mod = tf.clip_by_value(x_mod, 0, 1) + x_mods[j] = x_mod + + x_mods_stack = tf.stack(x_mods, axis=0) + + eigs_stack = tf.stack(eigs, axis=0) + energys = [] + + for i in range(20): + x_mods_stack = x_mods_stack + tf.random_normal(tf.shape(x_mods_stack), mean=0.0, stddev=0.005) + e_pos = model.forward(x_mods_stack, weights, label=Y_GT) + + x_grad = tf.gradients(e_pos, [x_mods_stack])[0] + x_mods_stack = x_mods_stack - 4*FLAGS.step_lr * x_grad + # x_mods_stack = x_mods_stack + 0.1 * eigs_stack + + x_mods_stack = tf.clip_by_value(x_mods_stack, 0, 1) + + energys.append(e_pos) + + x_refine = x_mods_stack + es = tf.stack(energys, axis=0) + + # target_vars['hessian'] = hessian + # target_vars['e'] = e + target_vars['v'] = v + target_vars['x_stack'] = x_stack + target_vars['x_refine'] = x_refine + target_vars['es'] = es + # target_vars['e_base'] = e_pos_base + + +def latent(test_dataloader, weights, model, target_vars, sess): + X = target_vars['X'] + Y_GT = target_vars['Y_GT'] + # hessian = target_vars['hessian'] + # e = target_vars['e'] + v = target_vars['v'] + x_stack = target_vars['x_stack'] + x_refine = target_vars['x_refine'] + es = target_vars['es'] + # e_pos_base = target_vars['e_base'] + # e_pos_hess_modify = target_vars['e_pos_hessian'] + + data_corrupt, data, label_gt = iter(test_dataloader).next() + data = data.numpy() + x_init = np.tile(data[0:1], (6, 1, 1)) + x_mod, = sess.run([x_stack], {X: data}) + # print("Value of original starting image: ", e_pos) + # print("Value of energy of hessian: ", e_pos_hess) + x_mod = x_mod.squeeze() + + n = 6 + x_mod_list = [x_init, x_mod] + + for i in range(n): + x_mod, evals = sess.run([x_refine, es], {X: x_mod}) + x_mod = x_mod.squeeze() + x_mod_list.append(x_mod) + print("Value of energies after evaluation: ", evals) + + x_mod_list = x_mod_list[:] + + + series_xmod = np.stack(x_mod_list, axis=1) + series_header = np.tile(data[0:1, None, :, :], (1, len(x_mod_list), 1, 1)) + + series_total = np.concatenate([series_header, series_xmod], axis=0) + + series_total_full = np.ones((*series_total.shape[:-2], 66, 66)) + + series_total_full[:, :, 1:-1, 1:-1] = series_total + + series_total = series_total_full + + series_total = series_total.transpose((0, 2, 1, 3)).reshape((-1, len(x_mod_list)*66)) + im_total = rescale_im(series_total) + imsave("latent_comb.png", im_total) + + +def construct_label(weights, X, Y, Y_GT, model, target_vars): + # for i in range(FLAGS.num_steps): + # Y = Y + tf.random_normal(tf.shape(Y), mean=0.0, stddev=0.03) + # e = model.forward(X, weights, label=Y) + + # Y_grad = tf.clip_by_value(tf.gradients(e, [Y])[0], -1, 1) + # Y = Y - 0.1 * Y_grad + # Y = tf.clip_by_value(Y, 0, 1) + + # Y = Y / tf.reduce_sum(Y, axis=[1], keepdims=True) + + e_bias = tf.get_variable('e_bias', shape=10, initializer=tf.initializers.zeros()) + l1_norm = tf.placeholder(shape=(), dtype=tf.float32) + l2_norm = tf.placeholder(shape=(), dtype=tf.float32) + + def compute_logit(X, stop_grad=False, num_steps=0): + batch_size = tf.shape(X)[0] + X = tf.reshape(X, (batch_size, 1, 32, 32, 3)) + X = tf.reshape(tf.tile(X, (1, 10, 1, 1, 1)), (batch_size * 10, 32, 32, 3)) + Y_new = tf.reshape(Y, (batch_size*10, 10)) + + X_min = X - 8 / 255. + X_max = X + 8 / 255. + + for i in range(num_steps): + X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005) + + energy_noise = model.forward(X, weights, label=Y, reuse=True) + x_grad = tf.gradients(energy_noise, [X])[0] + + + if FLAGS.proj_norm != 0.0: + x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm) + + X = X - FLAGS.step_lr * x_grad + X = tf.maximum(tf.minimum(X, X_max), X_min) + + energy = model.forward(X, weights, label=Y_new) + energy = -tf.reshape(energy, (batch_size, 10)) + + if stop_grad: + energy = tf.stop_gradient(energy) + + return energy + + + # eps_norm = 30 + X_min = X - l1_norm / 255. + X_max = X + l1_norm / 255. + + for i in range(FLAGS.pgd): + print("Constructed loop {} of pgd attack".format(i)) + X_init = X + if i == 0: + X = X + tf.to_float(tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)) / 255. + + logit = compute_logit(X) + loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logit) + + x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255. + X = X + 2 * x_grad + + if FLAGS.lnorm == -1: + X = tf.maximum(tf.minimum(X, X_max), X_min) + elif FLAGS.lnorm == 2: + X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255., axes=[1, 2, 3]) + + energy_stopped = compute_logit(X, stop_grad=True, num_steps=FLAGS.num_steps) + e_bias + + # # Y = tf.Print(Y, [Y]) + labels = tf.argmax(Y_GT, axis=1) + # max_z = tf.argmax(energy_stopped, axis=1) + + loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=energy_stopped) + optimizer = tf.train.AdamOptimizer(1e-2) + train_op = optimizer.minimize(loss) + + accuracy = tf.contrib.metrics.accuracy(tf.argmax(energy_stopped, axis=1), labels) + target_vars['accuracy'] = accuracy + target_vars['train_op'] = train_op + target_vars['l1_norm'] = l1_norm + target_vars['l2_norm'] = l2_norm + + +def construct_energy(weights, X, Y, Y_GT, model, target_vars): + energy = model.forward(X, weights, label=Y_GT) + + for i in range(FLAGS.num_steps): + X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005) + + energy_noise = model.forward(X, weights, label=Y_GT, reuse=True) + x_grad = tf.gradients(energy_noise, [X])[0] + + if FLAGS.proj_norm != 0.0: + x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm) + + X = X - FLAGS.step_lr * x_grad + X = tf.clip_by_value(X, 0, 1) + + + target_vars['energy'] = energy + target_vars['energy_end'] = energy_noise + + +def construct_steps(weights, X, Y_GT, model, target_vars): + n = 50 + scale_fac = 1.0 + + # if FLAGS.task == 'cycleclass': + # scale_fac = 10.0 + + X_mods = [] + X = tf.identity(X) + + mask = np.zeros((1, 32, 32, 3)) + + if FLAGS.task == "boxcorrupt": + mask[:, 16:, :, :] = 1 + else: + mask[:, :, :, :] = 1 + + mask = tf.Variable(tf.convert_to_tensor(mask, dtype=tf.float32), trainable=False) + + # X_targ = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32) + + for i in range(FLAGS.num_steps): + X_old = X + X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005*scale_fac) * mask + + energy_noise = model.forward(X, weights, label=Y_GT, reuse=True) + x_grad = tf.gradients(energy_noise, [X])[0] + + if FLAGS.proj_norm != 0.0: + x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm) + + X = X - FLAGS.step_lr * x_grad * scale_fac * mask + X = tf.clip_by_value(X, 0, 1) + + if i % n == (n-1): + X_mods.append(X) + + print("Constructing step {}".format(i)) + + target_vars['X_final'] = X + target_vars['X_mods'] = X_mods + + +def nearest_neighbor(dataset, sess, target_vars, logdir): + X = target_vars['X'] + Y_GT = target_vars['Y_GT'] + x_final = target_vars['X_final'] + + noise = np.random.uniform(0, 1, size=[10, 32, 32, 3]) + # label = np.random.randint(0, 10, size=[10]) + label = np.eye(10) + + coarse = noise + + for i in range(10): + x_new = sess.run([x_final], {X:coarse, Y_GT:label})[0] + coarse = x_new + + x_new_dense = x_new.reshape(10, 1, 32*32*3) + dataset_dense = dataset.reshape(1, 50000, 32*32*3) + + diff = np.square(x_new_dense - dataset_dense).sum(axis=2) + diff_idx = np.argsort(diff, axis=1) + + panel = np.zeros((32*10, 32*6, 3)) + + dataset_rescale = rescale_im(dataset) + x_new_rescale = rescale_im(x_new) + + for i in range(10): + panel[i*32:i*32+32, :32] = x_new_rescale[i] + for j in range(5): + panel[i*32:i*32+32, 32*j+32:32*j+64] = dataset_rescale[diff_idx[i, j]] + + imsave(osp.join(logdir, "nearest.png"), panel) + + +def main(): + + if FLAGS.dataset == "cifar10": + dataset = Cifar10(train=True, noise=False) + test_dataset = Cifar10(train=False, noise=False) + else: + dataset = Imagenet(train=True) + test_dataset = Imagenet(train=False) + + if FLAGS.svhn: + dataset = Svhn(train=True) + test_dataset = Svhn(train=False) + + if FLAGS.task == 'latent': + dataset = DSprites() + test_dataset = dataset + + dataloader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True) + test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True) + + hidden_dim = 128 + + if FLAGS.large_model: + model = ResNet32Large(num_filters=hidden_dim) + elif FLAGS.larger_model: + model = ResNet32Larger(num_filters=hidden_dim) + elif FLAGS.wider_model: + if FLAGS.dataset == 'imagenet': + model = ResNet32Wider(num_filters=196, train=False) + else: + model = ResNet32Wider(num_filters=256, train=False) + else: + model = ResNet32(num_filters=hidden_dim) + + if FLAGS.task == 'latent': + model = DspritesNet() + + weights = model.construct_weights('context_{}'.format(0)) + + total_parameters = 0 + for variable in tf.trainable_variables(): + # shape is an array of tf.Dimension + shape = variable.get_shape() + variable_parameters = 1 + for dim in shape: + variable_parameters *= dim.value + total_parameters += variable_parameters + print("Model has a total of {} parameters".format(total_parameters)) + + config = tf.ConfigProto() + sess = tf.InteractiveSession() + + if FLAGS.task == 'latent': + X = tf.placeholder(shape=(None, 64, 64), dtype = tf.float32) + else: + X = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32) + + if FLAGS.dataset == "cifar10": + Y = tf.placeholder(shape=(None, 10), dtype = tf.float32) + Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32) + elif FLAGS.dataset == "imagenet": + Y = tf.placeholder(shape=(None, 1000), dtype = tf.float32) + Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32) + + target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT} + + if FLAGS.task == 'label': + construct_label(weights, X, Y, Y_GT, model, target_vars) + elif FLAGS.task == 'labelfinetune': + construct_finetune_label(weights, X, Y, Y_GT, model, target_vars, ) + elif FLAGS.task == 'energyeval' or FLAGS.task == 'mixenergy': + construct_energy(weights, X, Y, Y_GT, model, target_vars) + elif FLAGS.task == 'anticorrupt' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'crossclass' or FLAGS.task == 'cycleclass' or FLAGS.task == 'democlass' or FLAGS.task == 'nearestneighbor': + construct_steps(weights, X, Y_GT, model, target_vars) + elif FLAGS.task == 'latent': + construct_latent(weights, X, Y_GT, model, target_vars) + + sess.run(tf.global_variables_initializer()) + saver = loader = tf.train.Saver(max_to_keep=10) + savedir = osp.join('cachedir', FLAGS.exp) + logdir = osp.join(FLAGS.logdir, FLAGS.exp) + if not osp.exists(logdir): + os.makedirs(logdir) + + initialize() + if FLAGS.resume_iter != -1: + model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter)) + resume_itr = FLAGS.resume_iter + + if FLAGS.task == 'label' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'labelfinetune' or FLAGS.task == "energyeval" or FLAGS.task == "crossclass" or FLAGS.task == "mixenergy": + optimistic_restore(sess, model_file) + # saver.restore(sess, model_file) + else: + # optimistic_restore(sess, model_file) + saver.restore(sess, model_file) + + if FLAGS.task == 'label': + if FLAGS.labelgrid: + vals = [] + if FLAGS.lnorm == -1: + for i in range(31): + accuracies = label(dataloader, test_dataloader, target_vars, sess, l1val=i) + vals.append(accuracies) + elif FLAGS.lnorm == 2: + for i in range(0, 100, 5): + accuracies = label(dataloader, test_dataloader, target_vars, sess, l2val=i) + vals.append(accuracies) + + np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals) + else: + label(dataloader, test_dataloader, target_vars, sess) + elif FLAGS.task == 'labelfinetune': + labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=FLAGS.lival, l2val=FLAGS.l2val) + elif FLAGS.task == 'energyeval': + energyeval(dataloader, test_dataloader, target_vars, sess) + elif FLAGS.task == 'mixenergy': + energyevalmix(dataloader, test_dataloader, target_vars, sess) + elif FLAGS.task == 'anticorrupt': + anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess) + elif FLAGS.task == 'boxcorrupt': + # boxcorrupt(test_dataloader, weights, model, target_vars, logdir, sess) + boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess) + elif FLAGS.task == 'crossclass': + crossclass(test_dataloader, weights, model, target_vars, logdir, sess) + elif FLAGS.task == 'cycleclass': + cycleclass(test_dataloader, weights, model, target_vars, logdir, sess) + elif FLAGS.task == 'democlass': + democlass(test_dataloader, weights, model, target_vars, logdir, sess) + elif FLAGS.task == 'nearestneighbor': + # print(dir(dataset)) + # print(type(dataset)) + nearest_neighbor(dataset.data.train_data / 255, sess, target_vars, logdir) + elif FLAGS.task == 'latent': + latent(test_dataloader, weights, model, target_vars, sess) + + +if __name__ == "__main__": + main() diff --git a/fid.py b/fid.py new file mode 100644 index 0000000..7aee938 --- /dev/null +++ b/fid.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +''' Calculates the Frechet Inception Distance (FID) to evalulate GANs. + +The FID metric calculates the distance between two distributions of images. +Typically, we have summary statistics (mean & covariance matrix) of one +of these distributions, while the 2nd distribution is given by a GAN. + +When run as a stand-alone program, it compares the distribution of +images that are stored as PNG/JPEG at a specified location with a +distribution given by summary statistics (in pickle format). + +The FID is calculated by assuming that X_1 and X_2 are the activations of +the pool_3 layer of the inception net for generated samples and real world +samples respectivly. + +See --help to see further details. +''' + +from __future__ import absolute_import, division, print_function +import numpy as np +import os +import gzip, pickle +import tensorflow as tf +from scipy.misc import imread +from scipy import linalg +import pathlib +import urllib +import tarfile +import warnings + +MODEL_DIR = '/tmp/imagenet' +DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' +pool3 = None + +class InvalidFIDException(Exception): + pass + +#------------------------------------------------------------------------------- +def get_fid_score(images, images_gt): + images = np.stack(images, 0) + images_gt = np.stack(images_gt, 0) + + with tf.Session() as sess: + m1, s1 = calculate_activation_statistics(images, sess) + m2, s2 = calculate_activation_statistics(images_gt, sess) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + + print("Obtained fid value of {}".format(fid_value)) + return fid_value + + +def create_inception_graph(pth): + """Creates a graph from saved GraphDef file.""" + # Creates graph from saved graph_def.pb. + with tf.gfile.FastGFile( pth, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString( f.read()) + _ = tf.import_graph_def( graph_def, name='FID_Inception_Net') +#------------------------------------------------------------------------------- + + +# code for handling inception net derived from +# https://github.com/openai/improved-gan/blob/master/inception_score/model.py +def _get_inception_layer(sess): + """Prepares inception net for batched usage and returns pool_3 layer. """ + layername = 'FID_Inception_Net/pool_3:0' + pool3 = sess.graph.get_tensor_by_name(layername) + ops = pool3.graph.get_operations() + for op_idx, op in enumerate(ops): + for o in op.outputs: + shape = o.get_shape() + if shape._dims != []: + shape = [s.value for s in shape] + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o.__dict__['_shape_val'] = tf.TensorShape(new_shape) + return pool3 +#------------------------------------------------------------------------------- + + +def get_activations(images, sess, batch_size=50, verbose=False): + """Calculates the activations of the pool_3 layer for all images. + + Params: + -- images : Numpy array of dimension (n_images, hi, wi, 3). The values + must lie between 0 and 256. + -- sess : current session + -- batch_size : the images numpy array is split into batches with batch size + batch_size. A reasonable batch size depends on the disposable hardware. + -- verbose : If set to True and parameter out_step is given, the number of calculated + batches is reported. + Returns: + -- A numpy array of dimension (num images, 2048) that contains the + activations of the given tensor when feeding inception with the query tensor. + """ + # inception_layer = _get_inception_layer(sess) + d0 = images.shape[0] + if batch_size > d0: + print("warning: batch size is bigger than the data size. setting batch size to data size") + batch_size = d0 + n_batches = d0//batch_size + n_used_imgs = n_batches*batch_size + pred_arr = np.empty((n_used_imgs,2048)) + for i in range(n_batches): + if verbose: + print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True) + start = i*batch_size + end = start + batch_size + batch = images[start:end] + pred = sess.run(pool3, {'ExpandDims:0': batch}) + pred_arr[start:end] = pred.reshape(batch_size,-1) + if verbose: + print(" done") + return pred_arr +#------------------------------------------------------------------------------- + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of the pool_3 layer of the + inception net ( like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted + on an representive data set. + -- sigma1: The covariance matrix over activations of the pool_3 layer for + generated samples. + -- sigma2: The covariance matrix over activations of the pool_3 layer, + precalcualted on an representive data set. + + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" + assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean +#------------------------------------------------------------------------------- + + +def calculate_activation_statistics(images, sess, batch_size=50, verbose=False): + """Calculation of the statistics used by the FID. + Params: + -- images : Numpy array of dimension (n_images, hi, wi, 3). The values + must lie between 0 and 255. + -- sess : current session + -- batch_size : the images numpy array is split into batches with batch size + batch_size. A reasonable batch size depends on the available hardware. + -- verbose : If set to True and parameter out_step is given, the number of calculated + batches is reported. + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of + the incption model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of + the incption model. + """ + act = get_activations(images, sess, batch_size, verbose) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma +#------------------------------------------------------------------------------- + + +#------------------------------------------------------------------------------- +# The following functions aren't needed for calculating the FID +# they're just here to make this module work as a stand-alone script +# for calculating FID scores +#------------------------------------------------------------------------------- +def check_or_download_inception(inception_path): + ''' Checks if the path to the inception file is valid, or downloads + the file if it is not present. ''' + INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' + if inception_path is None: + inception_path = '/tmp' + inception_path = pathlib.Path(inception_path) + model_file = inception_path / 'classify_image_graph_def.pb' + if not model_file.exists(): + print("Downloading Inception model") + from urllib import request + import tarfile + fn, _ = request.urlretrieve(INCEPTION_URL) + with tarfile.open(fn, mode='r') as f: + f.extract('classify_image_graph_def.pb', str(model_file.parent)) + return str(model_file) + + +def _handle_path(path, sess): + if path.endswith('.npz'): + f = np.load(path) + m, s = f['mu'][:], f['sigma'][:] + f.close() + else: + path = pathlib.Path(path) + files = list(path.glob('*.jpg')) + list(path.glob('*.png')) + x = np.array([imread(str(fn)).astype(np.float32) for fn in files]) + m, s = calculate_activation_statistics(x, sess) + return m, s + + +def calculate_fid_given_paths(paths, inception_path): + ''' Calculates the FID of two paths. ''' + inception_path = check_or_download_inception(inception_path) + + for p in paths: + if not os.path.exists(p): + raise RuntimeError("Invalid path: %s" % p) + + create_inception_graph(str(inception_path)) + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + m1, s1 = _handle_path(paths[0], sess) + m2, s2 = _handle_path(paths[1], sess) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + return fid_value + + +def _init_inception(): + global pool3 + if not os.path.exists(MODEL_DIR): + os.makedirs(MODEL_DIR) + filename = DATA_URL.split('/')[-1] + filepath = os.path.join(MODEL_DIR, filename) + if not os.path.exists(filepath): + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % ( + filename, float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) + print() + statinfo = os.stat(filepath) + print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') + tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) + with tf.gfile.FastGFile(os.path.join( + MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + _ = tf.import_graph_def(graph_def, name='') + # Works with an arbitrary minibatch size. + with tf.Session() as sess: + pool3 = sess.graph.get_tensor_by_name('pool_3:0') + ops = pool3.graph.get_operations() + for op_idx, op in enumerate(ops): + for o in op.outputs: + shape = o.get_shape() + if shape._dims != []: + shape = [s.value for s in shape] + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o.__dict__['_shape_val'] = tf.TensorShape(new_shape) + + +if pool3 is None: + _init_inception() diff --git a/hmc.py b/hmc.py new file mode 100644 index 0000000..68821c9 --- /dev/null +++ b/hmc.py @@ -0,0 +1,129 @@ +import tensorflow as tf +import numpy as np + +from tensorflow.python.platform import flags +flags.DEFINE_bool('proposal_debug', False, 'Print hmc acceptance raes') + +FLAGS = flags.FLAGS + +def kinetic_energy(velocity): + """Kinetic energy of the current velocity (assuming a standard Gaussian) + (x dot x) / 2 + + Parameters + ---------- + velocity : tf.Variable + Vector of current velocity + + Returns + ------- + kinetic_energy : float + """ + return 0.5 * tf.square(velocity) + +def hamiltonian(position, velocity, energy_function): + """Computes the Hamiltonian of the current position, velocity pair + + H = U(x) + K(v) + + U is the potential energy and is = -log_posterior(x) + + Parameters + ---------- + position : tf.Variable + Position or state vector x (sample from the target distribution) + velocity : tf.Variable + Auxiliary velocity variable + energy_function + Function from state to position to 'energy' + = -log_posterior + + Returns + ------- + hamitonian : float + """ + batch_size = tf.shape(velocity)[0] + kinetic_energy_flat = tf.reshape(kinetic_energy(velocity), (batch_size, -1)) + return tf.squeeze(energy_function(position)) + tf.reduce_sum(kinetic_energy_flat, axis=[1]) + +def leapfrog_step(x0, + v0, + neg_log_posterior, + step_size, + num_steps): + + # Start by updating the velocity a half-step + v = v0 - 0.5 * step_size * tf.gradients(neg_log_posterior(x0), x0)[0] + + # Initalize x to be the first step + x = x0 + step_size * v + + for i in range(num_steps): + # Compute gradient of the log-posterior with respect to x + gradient = tf.gradients(neg_log_posterior(x), x)[0] + + # Update velocity + v = v - step_size * gradient + + # x_clip = tf.clip_by_value(x, 0.0, 1.0) + # x = x_clip + # v_mask = 1 - 2 * tf.abs(tf.sign(x - x_clip)) + # v = v * v_mask + + # Update x + x = x + step_size * v + + # x = tf.clip_by_value(x, -0.01, 1.01) + + # x = tf.Print(x, [tf.reduce_min(x), tf.reduce_max(x), tf.reduce_mean(x)]) + + # Do a final update of the velocity for a half step + v = v - 0.5 * step_size * tf.gradients(neg_log_posterior(x), x)[0] + + # return new proposal state + return x, v + +def hmc(initial_x, + step_size, + num_steps, + neg_log_posterior): + """Summary + + Parameters + ---------- + initial_x : tf.Variable + Initial sample x ~ p + step_size : float + Step-size in Hamiltonian simulation + num_steps : int + Number of steps to take in Hamiltonian simulation + neg_log_posterior : str + Negative log posterior (unnormalized) for the target distribution + + Returns + ------- + sample : + Sample ~ target distribution + """ + + v0 = tf.random_normal(tf.shape(initial_x)) + x, v = leapfrog_step(initial_x, + v0, + step_size=step_size, + num_steps=num_steps, + neg_log_posterior=neg_log_posterior) + + orig = hamiltonian(initial_x, v0, neg_log_posterior) + current = hamiltonian(x, v, neg_log_posterior) + + prob_accept = tf.exp(orig - current) + + if FLAGS.proposal_debug: + prob_accept = tf.Print(prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))]) + + uniform = tf.random_uniform(tf.shape(prob_accept)) + keep_mask = (prob_accept > uniform) + # print(keep_mask.get_shape()) + + x_new = tf.where(keep_mask, x, initial_x) + return x_new diff --git a/imagenet_demo.py b/imagenet_demo.py new file mode 100644 index 0000000..9d79395 --- /dev/null +++ b/imagenet_demo.py @@ -0,0 +1,73 @@ +from models import ResNet128 +import numpy as np +import os.path as osp +from tensorflow.python.platform import flags +import tensorflow as tf +import imageio +from utils import optimistic_restore + + +flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored') +flags.DEFINE_integer('num_steps', 200, 'num of steps for conditional imagenet sampling') +flags.DEFINE_float('step_lr', 180., 'step size for Langevin dynamics') +flags.DEFINE_integer('batch_size', 16, 'number of steps to run') +flags.DEFINE_string('exp', 'default', 'name of experiments') +flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from') +flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization in weights in a model') +flags.DEFINE_bool('cclass', True, 'conditional models') +flags.DEFINE_bool('use_attention', False, 'using attention') + +FLAGS = flags.FLAGS + +def rescale_im(im): + return np.clip(im * 256, 0, 255).astype(np.uint8) + + +if __name__ == "__main__": + model = ResNet128(num_filters=64) + X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32) + LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32) + + sess = tf.InteractiveSession() + weights = model.construct_weights("context_0") + + x_mod = X_NOISE + x_mod = x_mod + tf.random_normal(tf.shape(x_mod), + mean=0.0, + stddev=0.005) + + energy_noise = energy_start = model.forward(x_mod, weights, label=LABEL, + reuse=True, stop_at_grad=False, stop_batch=True) + + x_grad = tf.gradients(energy_noise, [x_mod])[0] + energy_noise_old = energy_noise + + lr = FLAGS.step_lr + + x_last = x_mod - (lr) * x_grad + + x_mod = x_last + x_mod = tf.clip_by_value(x_mod, 0, 1) + x_output = x_mod + + sess.run(tf.global_variables_initializer()) + saver = loader = tf.train.Saver() + + logdir = osp.join(FLAGS.logdir, FLAGS.exp) + model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter)) + saver.restore(sess, model_file) + + lx = np.random.permutation(1000)[:16] + ims = [] + + # What to initialize sampling with. + x_mod = np.random.uniform(0, 1, size=(FLAGS.batch_size, 128, 128, 3)) + labels = np.eye(1000)[lx] + + for i in range(FLAGS.num_steps): + e, x_mod = sess.run([energy_noise, x_output], {X_NOISE:x_mod, LABEL:labels}) + ims.append(rescale_im(x_mod).reshape((4, 4, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((512, 512, 3))) + + imageio.mimwrite('sample.gif', ims) + + diff --git a/imagenet_preprocessing.py b/imagenet_preprocessing.py new file mode 100644 index 0000000..cbfde0c --- /dev/null +++ b/imagenet_preprocessing.py @@ -0,0 +1,337 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Image pre-processing utilities. +""" +import tensorflow as tf + + +IMAGE_DEPTH = 3 # color images + +import tensorflow as tf + +# _R_MEAN = 123.68 +# _G_MEAN = 116.78 +# _B_MEAN = 103.94 +# _CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN] +_CHANNEL_MEANS = [0.0, 0.0, 0.0] + +# The lower bound for the smallest side of the image for aspect-preserving +# resizing. For example, if an image is 500 x 1000, it will be resized to +# _RESIZE_MIN x (_RESIZE_MIN * 2). +_RESIZE_MIN = 128 + + +def _decode_crop_and_flip(image_buffer, bbox, num_channels): + """Crops the given image to a random part of the image, and randomly flips. + + We use the fused decode_and_crop op, which performs better than the two ops + used separately in series, but note that this requires that the image be + passed in as an un-decoded string Tensor. + + Args: + image_buffer: scalar string Tensor representing the raw JPEG image buffer. + bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] + where each coordinate is [0, 1) and the coordinates are arranged as + [ymin, xmin, ymax, xmax]. + num_channels: Integer depth of the image buffer for decoding. + + Returns: + 3-D tensor with cropped image. + + """ + # A large fraction of image datasets contain a human-annotated bounding box + # delineating the region of the image containing the object of interest. We + # choose to create a new bounding box for the object which is a randomly + # distorted version of the human-annotated bounding box that obeys an + # allowed range of aspect ratios, sizes and overlap with the human-annotated + # bounding box. If no box is supplied, then we assume the bounding box is + # the entire image. + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + tf.image.extract_jpeg_shape(image_buffer), + bounding_boxes=bbox, + min_object_covered=0.1, + aspect_ratio_range=[0.75, 1.33], + area_range=[0.05, 1.0], + max_attempts=100, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Reassemble the bounding box in the format the crop op requires. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) + + # Use the fused decode and crop op here, which is faster than each in series. + cropped = tf.image.decode_and_crop_jpeg( + image_buffer, crop_window, channels=num_channels) + + # Flip to add a little more random distortion in. + cropped = tf.image.random_flip_left_right(cropped) + return cropped + + +def _central_crop(image, crop_height, crop_width): + """Performs central crops of the given image list. + + Args: + image: a 3-D image tensor + crop_height: the height of the image following the crop. + crop_width: the width of the image following the crop. + + Returns: + 3-D tensor with cropped image. + """ + shape = tf.shape(input=image) + height, width = shape[0], shape[1] + + amount_to_be_cropped_h = (height - crop_height) + crop_top = amount_to_be_cropped_h // 2 + amount_to_be_cropped_w = (width - crop_width) + crop_left = amount_to_be_cropped_w // 2 + return tf.slice( + image, [crop_top, crop_left, 0], [crop_height, crop_width, -1]) + + +def _mean_image_subtraction(image, means, num_channels): + """Subtracts the given means from each image channel. + + For example: + means = [123.68, 116.779, 103.939] + image = _mean_image_subtraction(image, means) + + Note that the rank of `image` must be known. + + Args: + image: a tensor of size [height, width, C]. + means: a C-vector of values to subtract from each channel. + num_channels: number of color channels in the image that will be distorted. + + Returns: + the centered image. + + Raises: + ValueError: If the rank of `image` is unknown, if `image` has a rank other + than three or if the number of channels in `image` doesn't match the + number of values in `means`. + """ + if image.get_shape().ndims != 3: + raise ValueError('Input must be of size [height, width, C>0]') + + if len(means) != num_channels: + raise ValueError('len(means) must match the number of channels') + + # We have a 1-D tensor of means; convert to 3-D. + means = tf.expand_dims(tf.expand_dims(means, 0), 0) + + return image - means + + +def _smallest_size_at_least(height, width, resize_min): + """Computes new shape with the smallest side equal to `smallest_side`. + + Computes new shape with the smallest side equal to `smallest_side` while + preserving the original aspect ratio. + + Args: + height: an int32 scalar tensor indicating the current height. + width: an int32 scalar tensor indicating the current width. + resize_min: A python integer or scalar `Tensor` indicating the size of + the smallest side after resize. + + Returns: + new_height: an int32 scalar tensor indicating the new height. + new_width: an int32 scalar tensor indicating the new width. + """ + resize_min = tf.cast(resize_min, tf.float32) + + # Convert to floats to make subsequent calculations go smoothly. + height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32) + + smaller_dim = tf.minimum(height, width) + scale_ratio = resize_min / smaller_dim + + # Convert back to ints to make heights and widths that TF ops will accept. + new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32) + new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32) + + return new_height, new_width + + +def _aspect_preserving_resize(image, resize_min): + """Resize images preserving the original aspect ratio. + + Args: + image: A 3-D image `Tensor`. + resize_min: A python integer or scalar `Tensor` indicating the size of + the smallest side after resize. + + Returns: + resized_image: A 3-D tensor containing the resized image. + """ + shape = tf.shape(input=image) + height, width = shape[0], shape[1] + + new_height, new_width = _smallest_size_at_least(height, width, resize_min) + + return _resize_image(image, new_height, new_width) + + +def _resize_image(image, height, width): + """Simple wrapper around tf.resize_images. + + This is primarily to make sure we use the same `ResizeMethod` and other + details each time. + + Args: + image: A 3-D image `Tensor`. + height: The target height for the resized image. + width: The target width for the resized image. + + Returns: + resized_image: A 3-D tensor containing the resized image. The first two + dimensions have the shape [height, width]. + """ + return tf.image.resize_images( + image, [height, width], method=tf.image.ResizeMethod.BILINEAR, + align_corners=False) + + +def preprocess_image(image_buffer, bbox, output_height, output_width, + num_channels, is_training=False): + """Preprocesses the given image. + + Preprocessing includes decoding, cropping, and resizing for both training + and eval images. Training preprocessing, however, introduces some random + distortion of the image to improve accuracy. + + Args: + image_buffer: scalar string Tensor representing the raw JPEG image buffer. + bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] + where each coordinate is [0, 1) and the coordinates are arranged as + [ymin, xmin, ymax, xmax]. + output_height: The height of the image after preprocessing. + output_width: The width of the image after preprocessing. + num_channels: Integer depth of the image buffer for decoding. + is_training: `True` if we're preprocessing the image for training and + `False` otherwise. + + Returns: + A preprocessed image. + """ + if is_training: + # For training, we want to randomize some of the distortions. + image = _decode_crop_and_flip(image_buffer, bbox, num_channels) + image = _resize_image(image, output_height, output_width) + else: + # For validation, we want to decode, resize, then just crop the middle. + image = tf.image.decode_jpeg(image_buffer, channels=num_channels) + image = _aspect_preserving_resize(image, _RESIZE_MIN) + print(image) + image = _central_crop(image, output_height, output_width) + + image.set_shape([output_height, output_width, num_channels]) + + return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels) + + +def parse_example_proto(example_serialized): + """Parses an Example proto containing a training example of an image. + + The output of the build_image_data.py image preprocessing script is a dataset + containing serialized Example protocol buffers. Each Example proto contains + the following fields: + + image/height: 462 + image/width: 581 + image/colorspace: 'RGB' + image/channels: 3 + image/class/label: 615 + image/class/synset: 'n03623198' + image/class/text: 'knee pad' + image/object/bbox/xmin: 0.1 + image/object/bbox/xmax: 0.9 + image/object/bbox/ymin: 0.2 + image/object/bbox/ymax: 0.6 + image/object/bbox/label: 615 + image/format: 'JPEG' + image/filename: 'ILSVRC2012_val_00041207.JPEG' + image/encoded: + + Args: + example_serialized: scalar Tensor tf.string containing a serialized + Example protocol buffer. + + Returns: + image_buffer: Tensor tf.string containing the contents of a JPEG file. + label: Tensor tf.int32 containing the label. + bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] + where each coordinate is [0, 1) and the coordinates are arranged as + [ymin, xmin, ymax, xmax]. + text: Tensor tf.string containing the human-readable label. + """ + # Dense features in Example proto. + feature_map = { + 'image/encoded': tf.FixedLenFeature([], dtype=tf.string, + default_value=''), + 'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64, + default_value=-1), + 'image/class/text': tf.FixedLenFeature([], dtype=tf.string, + default_value=''), + } + sparse_float32 = tf.VarLenFeature(dtype=tf.float32) + # Sparse features in Example proto. + feature_map.update( + {k: sparse_float32 for k in ['image/object/bbox/xmin', + 'image/object/bbox/ymin', + 'image/object/bbox/xmax', + 'image/object/bbox/ymax']}) + + features = tf.parse_single_example(example_serialized, feature_map) + label = tf.cast(features['image/class/label'], dtype=tf.int32) + + xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0) + ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0) + xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0) + ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0) + + # Note that we impose an ordering of (y, x) just to make life difficult. + bbox = tf.concat([ymin, xmin, ymax, xmax], 0) + + # Force the variable number of bounding boxes into the shape + # [1, num_boxes, coords]. + bbox = tf.expand_dims(bbox, 0) + bbox = tf.transpose(bbox, [0, 2, 1]) + + return features['image/encoded'], label, bbox, features['image/class/text'] + + +class ImagenetPreprocessor: + def __init__(self, image_size, dtype, train): + self.image_size = image_size + self.dtype = dtype + self.train = train + + def preprocess(self, image_buffer, bbox): + # pylint: disable=g-import-not-at-top + image = preprocess_image(image_buffer, bbox, self.image_size, self.image_size, IMAGE_DEPTH, is_training=self.train) + return tf.cast(image, self.dtype) + + def parse_and_preprocess(self, value): + image_buffer, label_index, bbox, _ = parse_example_proto(value) + image = self.preprocess(image_buffer, bbox) + image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH]) + return label_index, image + diff --git a/inception.py b/inception.py new file mode 100644 index 0000000..6c76f3b --- /dev/null +++ b/inception.py @@ -0,0 +1,105 @@ +# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import sys +import tarfile + +import numpy as np +from six.moves import urllib +import tensorflow as tf +import glob +import scipy.misc +import math +import sys + +import horovod.tensorflow as hvd + +MODEL_DIR = '/tmp/imagenet' +DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' +softmax = None + +config = tf.ConfigProto() +config.gpu_options.visible_device_list = str(hvd.local_rank()) +sess = tf.Session(config=config) + +# Call this function with list of images. Each of elements should be a +# numpy array with values ranging from 0 to 255. +def get_inception_score(images, splits=10): + # For convenience + if len(images[0].shape) != 3: + return 0, 0 + + # Bypassing all the assertions so that we don't end prematuraly' + # assert(type(images) == list) + # assert(type(images[0]) == np.ndarray) + # assert(len(images[0].shape) == 3) + # assert(np.max(images[0]) > 10) + # assert(np.min(images[0]) >= 0.0) + inps = [] + for img in images: + img = img.astype(np.float32) + inps.append(np.expand_dims(img, 0)) + bs = 1 + preds = [] + n_batches = int(math.ceil(float(len(inps)) / float(bs))) + for i in range(n_batches): + sys.stdout.write(".") + sys.stdout.flush() + inp = inps[(i * bs):min((i + 1) * bs, len(inps))] + inp = np.concatenate(inp, 0) + pred = sess.run(softmax, {'ExpandDims:0': inp}) + preds.append(pred) + preds = np.concatenate(preds, 0) + scores = [] + for i in range(splits): + part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] + kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) + kl = np.mean(np.sum(kl, 1)) + scores.append(np.exp(kl)) + return np.mean(scores), np.std(scores) + +# This function is called automatically. +def _init_inception(): + global softmax + if not os.path.exists(MODEL_DIR): + os.makedirs(MODEL_DIR) + filename = DATA_URL.split('/')[-1] + filepath = os.path.join(MODEL_DIR, filename) + if not os.path.exists(filepath): + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % ( + filename, float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) + print() + statinfo = os.stat(filepath) + print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') + tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) + with tf.gfile.FastGFile(os.path.join( + MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + _ = tf.import_graph_def(graph_def, name='') + # Works with an arbitrary minibatch size. + pool3 = sess.graph.get_tensor_by_name('pool_3:0') + ops = pool3.graph.get_operations() + for op_idx, op in enumerate(ops): + for o in op.outputs: + shape = o.get_shape() + shape = [s.value for s in shape] + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o.set_shape(tf.TensorShape(new_shape)) + w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] + logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w) + softmax = tf.nn.softmax(logits) + +if softmax is None: + _init_inception() diff --git a/models.py b/models.py new file mode 100644 index 0000000..7099528 --- /dev/null +++ b/models.py @@ -0,0 +1,622 @@ +import tensorflow as tf +from tensorflow.python.platform import flags +import numpy as np +from utils import conv_block, get_weight, attention, conv_cond_concat, init_conv_weight, init_attention_weight, init_res_weight, smart_res_block, smart_res_block_optim, init_convt_weight +from utils import init_fc_weight, smart_conv_block, smart_fc_block, smart_atten_block, groupsort, smart_convt_block, swish + +flags.DEFINE_bool('swish_act', False, 'use the swish activation for dsprites') + +FLAGS = flags.FLAGS + + +class MnistNet(object): + def __init__(self, num_channels=1, num_filters=64): + + self.channels = num_channels + self.dim_hidden = num_filters + self.datasource = FLAGS.datasource + + if FLAGS.cclass: + self.label_size = 10 + else: + self.label_size = 0 + + def construct_weights(self, scope=''): + weights = {} + + dtype = tf.float32 + conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype) + fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype) + + classes = 1 + + with tf.variable_scope(scope): + init_conv_weight(weights, 'c1_pre', 3, 1, 64) + init_conv_weight(weights, 'c1', 4, 64, self.dim_hidden, classes=classes) + init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_fc_weight(weights, 'fc_dense', 4*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True) + init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False) + + if FLAGS.cclass: + self.label_size = 10 + else: + self.label_size = 0 + return weights + + def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, **kwargs): + channels = self.channels + weights = weights.copy() + inp = tf.reshape(inp, (tf.shape(inp)[0], 28, 28, 1)) + + if FLAGS.swish_act: + act = swish + else: + act = tf.nn.leaky_relu + + if stop_grad: + for k, v in weights.items(): + if type(v) == dict: + v = v.copy() + weights[k] = v + for k_sub, v_sub in v.items(): + v[k_sub] = tf.stop_gradient(v_sub) + else: + weights[k] = tf.stop_gradient(v) + + if FLAGS.cclass: + label_d = tf.reshape(label, shape=(tf.shape(label)[0], 1, 1, self.label_size)) + inp = conv_cond_concat(inp, label_d) + + h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act) + h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act) + h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act) + h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=False, extra_bias=False, activation=act) + + h5 = tf.reshape(h4, [-1, np.prod([int(dim) for dim in h4.get_shape()[1:]])]) + h6 = act(smart_fc_block(h5, weights, reuse, 'fc_dense')) + hidden6 = smart_fc_block(h6, weights, reuse, 'fc5') + + return hidden6 + + +class DspritesNet(object): + def __init__(self, num_channels=1, num_filters=64, cond_size=False, cond_shape=False, cond_pos=False, + cond_rot=False, label_size=1): + + self.channels = num_channels + self.dim_hidden = num_filters + self.img_size = 64 + self.label_size = label_size + + if FLAGS.cclass: + self.label_size = 3 + + try: + if FLAGS.dshape_only: + self.label_size = 3 + + if FLAGS.dpos_only: + self.label_size = 2 + + if FLAGS.dsize_only: + self.label_size = 1 + + if FLAGS.drot_only: + self.label_size = 2 + except: + pass + + if cond_size: + self.label_size = 1 + + if cond_shape: + self.label_size = 3 + + if cond_pos: + self.label_size = 2 + + if cond_rot: + self.label_size = 2 + + self.cond_size = cond_size + self.cond_shape = cond_shape + self.cond_pos = cond_pos + + def construct_weights(self, scope=''): + weights = {} + + dtype = tf.float32 + conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype) + fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype) + k = 5 + classes = self.label_size + + with tf.variable_scope(scope): + init_conv_weight(weights, 'c1_pre', 3, 1, 32) + init_conv_weight(weights, 'c1', 4, 32, self.dim_hidden, classes=classes) + init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_conv_weight(weights, 'c4', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_fc_weight(weights, 'fc_dense', 2*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True) + init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False) + + return weights + + def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False, return_logit=False): + channels = self.channels + batch_size = tf.shape(inp)[0] + + inp = tf.reshape(inp, (batch_size, 64, 64, 1)) + + if FLAGS.swish_act: + act = swish + else: + act = tf.nn.leaky_relu + + if not FLAGS.cclass: + label = None + + weights = weights.copy() + + if stop_grad: + for k, v in weights.items(): + if type(v) == dict: + v = v.copy() + weights[k] = v + for k_sub, v_sub in v.items(): + v[k_sub] = tf.stop_gradient(v_sub) + else: + weights[k] = tf.stop_gradient(v) + + h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act) + h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act) + h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act) + h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=True, extra_bias=True, activation=act) + h5 = smart_conv_block(h4, weights, reuse, 'c4', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act) + + hidden6 = tf.reshape(h5, (tf.shape(h5)[0], -1)) + hidden7 = act(smart_fc_block(hidden6, weights, reuse, 'fc_dense')) + energy = smart_fc_block(hidden7, weights, reuse, 'fc5') + + if return_logit: + return hidden7 + else: + return energy + + + +class ResNet32(object): + def __init__(self, num_channels=3, num_filters=128): + + self.channels = num_channels + self.dim_hidden = num_filters + self.groupsort = groupsort() + + def construct_weights(self, scope=''): + weights = {} + dtype = tf.float32 + + if FLAGS.cclass: + classes = 10 + else: + classes = 1 + + with tf.variable_scope(scope): + # First block + init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden) + init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_2', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_3', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden) + init_fc_weight(weights, 'fc5', 2*self.dim_hidden , 1, spec_norm=False) + + init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True) + + return weights + + def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False): + weights = weights.copy() + batch = tf.shape(inp)[0] + + act = tf.nn.leaky_relu + + if not FLAGS.cclass: + label = None + + if stop_grad: + for k, v in weights.items(): + if type(v) == dict: + v = v.copy() + weights[k] = v + for k_sub, v_sub in v.items(): + v[k_sub] = tf.stop_gradient(v_sub) + else: + weights[k] = tf.stop_gradient(v) + + # Make sure gradients are modified a bit + inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False) + + hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, act=act) + hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, act=act) + hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, label=label, act=act) + + if FLAGS.use_attention: + hidden4 = smart_atten_block(hidden3, weights, reuse, 'atten', stop_at_grad=stop_at_grad, label=label) + else: + hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, act=act) + + hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', stop_batch=stop_batch, adaptive=False, label=label, act=act) + compact = hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) + hidden6 = tf.nn.relu(hidden6) + hidden5 = tf.reduce_sum(hidden6, [1, 2]) + + hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5') + + energy = hidden6 + + return energy + + +class ResNet32Large(object): + def __init__(self, num_channels=3, num_filters=128, train=False): + + self.channels = num_channels + self.dim_hidden = num_filters + self.dropout = train + self.train = train + + def construct_weights(self, scope=''): + weights = {} + dtype = tf.float32 + + if FLAGS.cclass: + classes = 10 + else: + classes = 1 + + with tf.variable_scope(scope): + # First block + init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden) + init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False) + + init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden, trainable_gamma=True) + + return weights + + def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False): + weights = weights.copy() + batch = tf.shape(inp)[0] + + if not FLAGS.cclass: + label = None + + if stop_grad: + for k, v in weights.items(): + if type(v) == dict: + v = v.copy() + weights[k] = v + for k_sub, v_sub in v.items(): + v[k_sub] = tf.stop_gradient(v_sub) + else: + weights[k] = tf.stop_gradient(v) + + # Make sure gradients are modified a bit + inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False) + + dropout = self.dropout + train = self.train + + hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, dropout=dropout, train=train) + hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train) + hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train) + hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train) + + if FLAGS.use_attention: + hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad) + else: + hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train) + + hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train) + + hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train) + hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train) + + compact = hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train) + + if FLAGS.cclass: + hidden6 = tf.nn.leaky_relu(hidden9) + else: + hidden6 = tf.nn.relu(hidden9) + hidden5 = tf.reduce_sum(hidden6, [1, 2]) + + hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5') + + energy = hidden6 + + return energy + + +class ResNet32Wider(object): + def __init__(self, num_channels=3, num_filters=128, train=False): + + self.channels = num_channels + self.dim_hidden = num_filters + self.dropout = train + self.train = train + + def construct_weights(self, scope=''): + weights = {} + dtype = tf.float32 + + if FLAGS.cclass and FLAGS.dataset == "cifar10": + classes = 10 + elif FLAGS.cclass and FLAGS.dataset == "imagenet": + classes = 1000 + else: + classes = 1 + + with tf.variable_scope(scope): + # First block + init_conv_weight(weights, 'c1_pre', 3, self.channels, 128) + init_res_weight(weights, 'res_optim', 3, 128, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False) + + init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True) + + return weights + + def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False): + weights = weights.copy() + batch = tf.shape(inp)[0] + + if not FLAGS.cclass: + label = None + + if stop_grad: + for k, v in weights.items(): + if type(v) == dict: + v = v.copy() + weights[k] = v + for k_sub, v_sub in v.items(): + v[k_sub] = tf.stop_gradient(v_sub) + else: + weights[k] = tf.stop_gradient(v) + + if FLAGS.swish_act: + act = swish + else: + act = tf.nn.leaky_relu + + # Make sure gradients are modified a bit + inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act) + dropout = self.dropout + train = self.train + + hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=True, label=label, dropout=dropout, train=train) + + if FLAGS.use_attention: + hidden2 = smart_atten_block(hidden1, weights, reuse, 'atten', train=train, dropout=dropout, stop_at_grad=stop_at_grad) + else: + hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act) + + hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act) + hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) + + hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) + + hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) + + hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) + hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) + + hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) + + if FLAGS.swish_act: + hidden6 = act(hidden9) + else: + hidden6 = tf.nn.relu(hidden9) + + hidden5 = tf.reduce_sum(hidden6, [1, 2]) + hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5') + energy = hidden6 + + return energy + + +class ResNet32Larger(object): + def __init__(self, num_channels=3, num_filters=128): + + self.channels = num_channels + self.dim_hidden = num_filters + + def construct_weights(self, scope=''): + weights = {} + dtype = tf.float32 + + if FLAGS.cclass: + classes = 10 + else: + classes = 1 + + with tf.variable_scope(scope): + # First block + init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden) + init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_2a', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_2b', 3, self.dim_hidden, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_5a', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_5b', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_8a', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_8b', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden) + init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False) + + init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True) + + return weights + + def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False): + weights = weights.copy() + batch = tf.shape(inp)[0] + + if not FLAGS.cclass: + label = None + + if stop_grad: + for k, v in weights.items(): + if type(v) == dict: + v = v.copy() + weights[k] = v + for k_sub, v_sub in v.items(): + v[k_sub] = tf.stop_gradient(v_sub) + else: + weights[k] = tf.stop_gradient(v) + + # Make sure gradients are modified a bit + inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False) + + hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label) + hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label) + hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label) + hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2a', stop_batch=stop_batch, downsample=False, adaptive=False, label=label) + hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2b', stop_batch=stop_batch, downsample=False, adaptive=False, label=label) + hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label) + + if FLAGS.use_attention: + hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad) + else: + hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) + + hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) + + hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) + hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) + hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label) + hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) + hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) + hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) + compact = hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) + + if FLAGS.cclass: + hidden6 = tf.nn.leaky_relu(hidden9) + else: + hidden6 = tf.nn.relu(hidden9) + hidden5 = tf.reduce_sum(hidden6, [1, 2]) + + hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5') + + energy = hidden6 + + return energy + + +class ResNet128(object): + """Construct the convolutional network specified in MAML""" + + def __init__(self, num_channels=3, num_filters=64, train=False): + + self.channels = num_channels + self.dim_hidden = num_filters + self.dropout = train + self.train = train + + def construct_weights(self, scope=''): + weights = {} + dtype = tf.float32 + + classes = 1000 + + with tf.variable_scope(scope): + # First block + init_conv_weight(weights, 'c1_pre', 3, self.channels, 64) + init_res_weight(weights, 'res_optim', 3, 64, self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 8*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_9', 3, 8*self.dim_hidden, 8*self.dim_hidden, classes=classes) + init_res_weight(weights, 'res_10', 3, 8*self.dim_hidden, 8*self.dim_hidden, classes=classes) + init_fc_weight(weights, 'fc5', 8*self.dim_hidden , 1, spec_norm=False) + + + init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden / 2., trainable_gamma=True) + + return weights + + def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False): + weights = weights.copy() + batch = tf.shape(inp)[0] + + if not FLAGS.cclass: + label = None + + + if stop_grad: + for k, v in weights.items(): + if type(v) == dict: + v = v.copy() + weights[k] = v + for k_sub, v_sub in v.items(): + v[k_sub] = tf.stop_gradient(v_sub) + else: + weights[k] = tf.stop_gradient(v) + + if FLAGS.swish_act: + act = swish + else: + act = tf.nn.leaky_relu + + dropout = self.dropout + train = self.train + + # Make sure gradients are modified a bit + inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act) + hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', label=label, dropout=dropout, train=train, downsample=True, adaptive=False) + + if FLAGS.use_attention: + hidden1 = smart_atten_block(hidden1, weights, reuse, 'atten', stop_at_grad=stop_at_grad) + + hidden2 = smart_res_block(hidden1, weights, reuse, 'res_3', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act) + hidden3 = smart_res_block(hidden2, weights, reuse, 'res_5', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act) + hidden4 = smart_res_block(hidden3, weights, reuse, 'res_7', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=True) + hidden5 = smart_res_block(hidden4, weights, reuse, 'res_9', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=False) + hidden6 = smart_res_block(hidden5, weights, reuse, 'res_10', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=False, adaptive=False) + + if FLAGS.swish_act: + hidden6 = act(hidden6) + else: + hidden6 = tf.nn.relu(hidden6) + + hidden5 = tf.reduce_sum(hidden6, [1, 2]) + hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5') + energy = hidden6 + + return energy diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4573f37 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +scipy==1.1.0 +horovod==0.16.0 +torch==1.5.0 +torchvision==0.6.0 +six==1.11.0 +imageio==2.8.0 +tqdm==4.46.0 +matplotlib==3.2.1 +mpi4py==3.0.3 +numpy==1.18.4 +Pillow==5.4.1 +baselines==0.1.5 +scikit-image==0.14.2 +scikit_learn +tensorflow==1.13.1 +cloudpickle==1.3.0 +Cython==0.29.17 \ No newline at end of file diff --git a/test_inception.py b/test_inception.py new file mode 100644 index 0000000..ca7e55b --- /dev/null +++ b/test_inception.py @@ -0,0 +1,333 @@ +import tensorflow as tf +import numpy as np +from tensorflow.python.platform import flags +from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128 +import os.path as osp +import os +from utils import optimistic_restore, remap_restore, optimistic_remap_restore +from tqdm import tqdm +import random +from scipy.misc import imsave +from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, TFImagenetLoader +from torch.utils.data import DataLoader +from baselines.common.tf_util import initialize + +import horovod.tensorflow as hvd +hvd.init() + +from inception import get_inception_score +from fid import get_fid_score + +flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored') +flags.DEFINE_string('exp', 'default', 'name of experiments') +flags.DEFINE_bool('cclass', False, 'whether to condition on class') + +# Architecture settings +flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not') +flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights') +flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution') +flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network') +flags.DEFINE_float('step_lr', 10.0, 'Size of steps for gradient descent') +flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label') +flags.DEFINE_float('proj_norm', 0.05, 'Maximum change of input images') +flags.DEFINE_integer('batch_size', 512, 'batch size') +flags.DEFINE_integer('resume_iter', -1, 'resume iteration') +flags.DEFINE_integer('ensemble', 10, 'number of ensembles') +flags.DEFINE_integer('im_number', 50000, 'number of ensembles') +flags.DEFINE_integer('repeat_scale', 100, 'number of repeat iterations') +flags.DEFINE_float('noise_scale', 0.005, 'amount of noise to output') +flags.DEFINE_integer('idx', 0, 'save index') +flags.DEFINE_integer('nomix', 10, 'number of intervals to stop mixing') +flags.DEFINE_bool('scaled', True, 'whether to scale noise added') +flags.DEFINE_bool('large_model', False, 'whether to use a small or large model') +flags.DEFINE_bool('larger_model', False, 'Whether to use a large model') +flags.DEFINE_bool('wider_model', False, 'Whether to use a large model') +flags.DEFINE_bool('single', False, 'single ') +flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single') +flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or imagenet or imagenetfull') + +FLAGS = flags.FLAGS + +class InceptionReplayBuffer(object): + def __init__(self, size): + """Create Replay buffer. + Parameters + ---------- + size: int + Max number of transitions to store in the buffer. When the buffer + overflows the old memories are dropped. + """ + self._storage = [] + self._label_storage = [] + self._maxsize = size + self._next_idx = 0 + + def __len__(self): + return len(self._storage) + + def add(self, ims, labels): + batch_size = ims.shape[0] + if self._next_idx >= len(self._storage): + self._storage.extend(list(ims)) + self._label_storage.extend(list(labels)) + else: + if batch_size + self._next_idx < self._maxsize: + self._storage[self._next_idx:self._next_idx+batch_size] = list(ims) + self._label_storage[self._next_idx:self._next_idx+batch_size] = list(labels) + else: + split_idx = self._maxsize - self._next_idx + self._storage[self._next_idx:] = list(ims)[:split_idx] + self._storage[:batch_size-split_idx] = list(ims)[split_idx:] + self._label_storage[self._next_idx:] = list(labels)[:split_idx] + self._label_storage[:batch_size-split_idx] = list(labels)[split_idx:] + + self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize + + def _encode_sample(self, idxes): + ims = [] + labels = [] + for i in idxes: + ims.append(self._storage[i]) + labels.append(self._label_storage[i]) + return np.array(ims), np.array(labels) + + def sample(self, batch_size): + """Sample a batch of experiences. + Parameters + ---------- + batch_size: int + How many transitions to sample. + Returns + ------- + obs_batch: np.array + batch of observations + act_batch: np.array + batch of actions executed given obs_batch + rew_batch: np.array + rewards received as results of executing act_batch + next_obs_batch: np.array + next set of observations seen after executing act_batch + done_mask: np.array + done_mask[i] = 1 if executing act_batch[i] resulted in + the end of an episode and 0 otherwise. + """ + idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] + return self._encode_sample(idxes), idxes + + def set_elms(self, idxes, data, labels): + for i, ix in enumerate(idxes): + self._storage[ix] = data[i] + self._label_storage[ix] = labels[i] + + +def rescale_im(im): + return np.clip(im * 256, 0, 255).astype(np.uint8) + +def compute_inception(sess, target_vars): + X_START = target_vars['X_START'] + Y_GT = target_vars['Y_GT'] + X_finals = target_vars['X_finals'] + NOISE_SCALE = target_vars['NOISE_SCALE'] + energy_noise = target_vars['energy_noise'] + + size = FLAGS.im_number + num_steps = size // 1000 + + images = [] + test_ims = [] + + + if FLAGS.dataset == "cifar10": + test_dataset = Cifar10(full=True, noise=False) + elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull": + test_dataset = Imagenet(train=False) + + if FLAGS.dataset != "imagenetfull": + test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False) + else: + test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1) + + for data_corrupt, data, label_gt in tqdm(test_dataloader): + data = data.numpy() + test_ims.extend(list(rescale_im(data))) + + if FLAGS.dataset == "imagenetfull" and len(test_ims) > 60000: + test_ims = test_ims[:60000] + break + + + # n = min(len(images), len(test_ims)) + print(len(test_ims)) + # fid = get_fid_score(test_ims[:30000], test_ims[-30000:]) + # print("Base FID of score {}".format(fid)) + + if FLAGS.dataset == "cifar10": + classes = 10 + else: + classes = 1000 + + if FLAGS.dataset == "imagenetfull": + n = 128 + else: + n = 32 + + for j in range(num_steps): + itr = int(1000 / 500 * FLAGS.repeat_scale) + data_buffer = InceptionReplayBuffer(1000) + curr_index = 0 + + identity = np.eye(classes) + + for i in tqdm(range(itr)): + model_index = curr_index % len(X_finals) + x_final = X_finals[model_index] + + noise_scale = [1] + if len(data_buffer) < 1000: + x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3)) + label = np.random.randint(0, classes, (FLAGS.batch_size)) + label = identity[label] + x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0] + data_buffer.add(x_new, label) + else: + (x_init, label), idx = data_buffer.sample(FLAGS.batch_size) + keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99) + label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9) + label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size)) + label_corrupt = identity[label_corrupt] + x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3)) + + if i < itr - FLAGS.nomix: + x_init[keep_mask] = x_init_corrupt[keep_mask] + label[label_keep_mask] = label_corrupt[label_keep_mask] + # else: + # noise_scale = [0.7] + + x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale}) + data_buffer.set_elms(idx, x_new, label) + + if FLAGS.im_number != 50000: + print(np.mean(e_noise), np.std(e_noise)) + + curr_index += 1 + + ims = np.array(data_buffer._storage[:1000]) + ims = rescale_im(ims) + + images.extend(list(ims)) + + saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx)) + + ims = ims[:100] + + if FLAGS.dataset != "imagenetfull": + im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3)) + else: + im_panel = ims.reshape((10, 10, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((1280, 1280, 3)) + imsave(saveim, im_panel) + + print("Saved image!!!!") + splits = max(1, len(images) // 5000) + score, std = get_inception_score(images, splits=splits) + print("Inception score of {} with std of {}".format(score, std)) + + # FID score + # n = min(len(images), len(test_ims)) + fid = get_fid_score(images, test_ims) + print("FID of score {}".format(fid)) + + + + +def main(model_list): + + if FLAGS.dataset == "imagenetfull": + model = ResNet128(num_filters=64) + elif FLAGS.large_model: + model = ResNet32Large(num_filters=128) + elif FLAGS.larger_model: + model = ResNet32Larger(num_filters=hidden_dim) + elif FLAGS.wider_model: + model = ResNet32Wider(num_filters=256, train=False) + else: + model = ResNet32(num_filters=128) + + # config = tf.ConfigProto() + sess = tf.InteractiveSession() + + logdir = osp.join(FLAGS.logdir, FLAGS.exp) + weights = [] + + for i, model_num in enumerate(model_list): + weight = model.construct_weights('context_{}'.format(i)) + initialize() + save_file = osp.join(logdir, 'model_{}'.format(model_num)) + + v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(i)) + v_map = {(v.name.replace('context_{}'.format(i), 'context_0')[:-2]): v for v in v_list} + saver = tf.train.Saver(v_map) + try: + saver.restore(sess, save_file) + except: + optimistic_remap_restore(sess, save_file, i) + weights.append(weight) + + + if FLAGS.dataset == "imagenetfull": + X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype = tf.float32) + else: + X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32) + + if FLAGS.dataset == "cifar10": + Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32) + else: + Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32) + + NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32) + + X_finals = [] + + + # Seperate loops + for weight in weights: + X = X_START + + steps = tf.constant(0) + c = lambda i, x: tf.less(i, FLAGS.num_steps) + def langevin_step(counter, X): + scale_rate = 1 + + X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE) + + energy_noise = model.forward(X, weight, label=Y_GT, reuse=True) + x_grad = tf.gradients(energy_noise, [X])[0] + + if FLAGS.proj_norm != 0.0: + x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm) + + X = X - FLAGS.step_lr * x_grad * scale_rate + X = tf.clip_by_value(X, 0, 1) + + counter = counter + 1 + + return counter, X + + steps, X = tf.while_loop(c, langevin_step, (steps, X)) + energy_noise = model.forward(X, weight, label=Y_GT, reuse=True) + X_final = X + X_finals.append(X_final) + + target_vars = {} + target_vars['X_START'] = X_START + target_vars['Y_GT'] = Y_GT + target_vars['X_finals'] = X_finals + target_vars['NOISE_SCALE'] = NOISE_SCALE + target_vars['energy_noise'] = energy_noise + + compute_inception(sess, target_vars) + + +if __name__ == "__main__": + # model_list = [117000, 116700] + model_list = [FLAGS.resume_iter - 300*i for i in range(FLAGS.ensemble)] + main(model_list) diff --git a/train.py b/train.py new file mode 100644 index 0000000..716c4ea --- /dev/null +++ b/train.py @@ -0,0 +1,941 @@ +import tensorflow as tf +import numpy as np +from tensorflow.python.platform import flags + +from data import Imagenet, Cifar10, DSprites, Mnist, TFImagenetLoader +from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, MnistNet, ResNet128 +import os.path as osp +import os +from baselines.logger import TensorBoardOutputFormat +from utils import average_gradients, ReplayBuffer, optimistic_restore +from tqdm import tqdm +import random +from torch.utils.data import DataLoader +import time as time +from io import StringIO +from tensorflow.core.util import event_pb2 +import torch +import numpy as np +from custom_adam import AdamOptimizer +from scipy.misc import imsave +import matplotlib.pyplot as plt +from hmc import hmc + +from mpi4py import MPI +comm = MPI.COMM_WORLD +rank = comm.Get_rank() + +import horovod.tensorflow as hvd +hvd.init() + +from inception import get_inception_score + +torch.manual_seed(hvd.rank()) +np.random.seed(hvd.rank()) +tf.set_random_seed(hvd.rank()) + +FLAGS = flags.FLAGS + + +# Dataset Options +flags.DEFINE_string('datasource', 'random', + 'initialization for chains, either random or default (decorruption)') +flags.DEFINE_string('dataset','mnist', + 'dsprites, cifar10, imagenet (32x32) or imagenetfull (128x128)') +flags.DEFINE_integer('batch_size', 256, 'Size of inputs') +flags.DEFINE_bool('single', False, 'whether to debug by training on a single image') +flags.DEFINE_integer('data_workers', 4, + 'Number of different data workers to load data in parallel') + +# General Experiment Settings +flags.DEFINE_string('logdir', 'cachedir', + 'location where log of experiments will be stored') +flags.DEFINE_string('exp', 'default', 'name of experiments') +flags.DEFINE_integer('log_interval', 10, 'log outputs every so many batches') +flags.DEFINE_integer('save_interval', 1000,'save outputs every so many batches') +flags.DEFINE_integer('test_interval', 1000,'evaluate outputs every so many batches') +flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from') +flags.DEFINE_bool('train', True, 'whether to train or test') +flags.DEFINE_integer('epoch_num', 10000, 'Number of Epochs to train on') +flags.DEFINE_float('lr', 3e-4, 'Learning for training') +flags.DEFINE_integer('num_gpus', 1, 'number of gpus to train on') + +# EBM Specific Experiments Settings +flags.DEFINE_float('ml_coeff', 1.0, 'Maximum Likelihood Coefficients') +flags.DEFINE_float('l2_coeff', 1.0, 'L2 Penalty training') +flags.DEFINE_bool('cclass', False, 'Whether to conditional training in models') +flags.DEFINE_bool('model_cclass', False,'use unsupervised clustering to infer fake labels') +flags.DEFINE_integer('temperature', 1, 'Temperature for energy function') +flags.DEFINE_string('objective', 'cd', 'use either contrastive divergence objective(least stable),' + 'logsumexp(more stable)' + 'softplus(most stable)') +flags.DEFINE_bool('zero_kl', False, 'whether to zero out the kl loss') + +# Setting for MCMC sampling +flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images') +flags.DEFINE_string('proj_norm_type', 'li', 'Either li or l2 ball projection') +flags.DEFINE_integer('num_steps', 20, 'Steps of gradient descent for training') +flags.DEFINE_float('step_lr', 1.0, 'Size of steps for gradient descent') +flags.DEFINE_bool('replay_batch', False, 'Use MCMC chains initialized from a replay buffer.') +flags.DEFINE_bool('hmc', False, 'Whether to use HMC sampling to train models') +flags.DEFINE_float('noise_scale', 1.,'Relative amount of noise for MCMC') +flags.DEFINE_bool('pcd', False, 'whether to use pcd training instead') + +# Architecture Settings +flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets') +flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights') +flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network') +flags.DEFINE_bool('large_model', False, 'whether to use a large model') +flags.DEFINE_bool('larger_model', False, 'Deeper ResNet32 Network') +flags.DEFINE_bool('wider_model', False, 'Wider ResNet32 Network') + +# Dataset settings +flags.DEFINE_bool('mixup', False, 'whether to add mixup to training images') +flags.DEFINE_bool('augment', False, 'whether to augmentations to images') +flags.DEFINE_float('rescale', 1.0, 'Factor to rescale inputs from 0-1 box') + +# Dsprites specific experiments +flags.DEFINE_bool('cond_shape', False, 'condition of shape type') +flags.DEFINE_bool('cond_size', False, 'condition of shape size') +flags.DEFINE_bool('cond_pos', False, 'condition of position loc') +flags.DEFINE_bool('cond_rot', False, 'condition of rot') + +FLAGS.step_lr = FLAGS.step_lr * FLAGS.rescale + +FLAGS.batch_size *= FLAGS.num_gpus + +print("{} batch size".format(FLAGS.batch_size)) + + +def compress_x_mod(x_mod): + x_mod = (255 * np.clip(x_mod, 0, FLAGS.rescale) / FLAGS.rescale).astype(np.uint8) + return x_mod + + +def decompress_x_mod(x_mod): + x_mod = x_mod / 256 * FLAGS.rescale + \ + np.random.uniform(0, 1 / 256 * FLAGS.rescale, x_mod.shape) + return x_mod + + +def make_image(tensor): + """Convert an numpy representation image to Image protobuf""" + from PIL import Image + if len(tensor.shape) == 4: + _, height, width, channel = tensor.shape + elif len(tensor.shape) == 3: + height, width, channel = tensor.shape + elif len(tensor.shape) == 2: + height, width = tensor.shape + channel = 1 + tensor = tensor.astype(np.uint8) + image = Image.fromarray(tensor) + import io + output = io.BytesIO() + image.save(output, format='PNG') + image_string = output.getvalue() + output.close() + return tf.Summary.Image(height=height, + width=width, + colorspace=channel, + encoded_image_string=image_string) + + +def log_image(im, logger, tag, step=0): + im = make_image(im) + + summary = [tf.Summary.Value(tag=tag, image=im)] + summary = tf.Summary(value=summary) + event = event_pb2.Event(summary=summary) + event.step = step + logger.writer.WriteEvent(event) + logger.writer.Flush() + + +def rescale_im(image): + image = np.clip(image, 0, FLAGS.rescale) + if FLAGS.dataset == 'mnist' or FLAGS.dataset == 'dsprites': + return (np.clip((FLAGS.rescale - image) * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8) + else: + return (np.clip(image * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8) + + +def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir): + X = target_vars['X'] + Y = target_vars['Y'] + X_NOISE = target_vars['X_NOISE'] + train_op = target_vars['train_op'] + energy_pos = target_vars['energy_pos'] + energy_neg = target_vars['energy_neg'] + loss_energy = target_vars['loss_energy'] + loss_ml = target_vars['loss_ml'] + loss_total = target_vars['total_loss'] + gvs = target_vars['gvs'] + x_grad = target_vars['x_grad'] + x_grad_first = target_vars['x_grad_first'] + x_off = target_vars['x_off'] + temp = target_vars['temp'] + x_mod = target_vars['x_mod'] + LABEL = target_vars['LABEL'] + LABEL_POS = target_vars['LABEL_POS'] + weights = target_vars['weights'] + test_x_mod = target_vars['test_x_mod'] + eps = target_vars['eps_begin'] + label_ent = target_vars['label_ent'] + + if FLAGS.use_attention: + gamma = weights[0]['atten']['gamma'] + else: + gamma = tf.zeros(1) + + val_output = [test_x_mod] + + gvs_dict = dict(gvs) + + log_output = [ + train_op, + energy_pos, + energy_neg, + eps, + loss_energy, + loss_ml, + loss_total, + x_grad, + x_off, + x_mod, + gamma, + x_grad_first, + label_ent, + *gvs_dict.keys()] + output = [train_op, x_mod] + + replay_buffer = ReplayBuffer(10000) + itr = resume_iter + x_mod = None + gd_steps = 1 + + dataloader_iterator = iter(dataloader) + best_inception = 0.0 + + for epoch in range(FLAGS.epoch_num): + for data_corrupt, data, label in dataloader: + data_corrupt = data_corrupt_init = data_corrupt.numpy() + data_corrupt_init = data_corrupt.copy() + + data = data.numpy() + label = label.numpy() + + label_init = label.copy() + + if FLAGS.mixup: + idx = np.random.permutation(data.shape[0]) + lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1)) + data = data * lam + data[idx] * (1 - lam) + + if FLAGS.replay_batch and (x_mod is not None): + replay_buffer.add(compress_x_mod(x_mod)) + + if len(replay_buffer) > FLAGS.batch_size: + replay_batch = replay_buffer.sample(FLAGS.batch_size) + replay_batch = decompress_x_mod(replay_batch) + replay_mask = ( + np.random.uniform( + 0, + FLAGS.rescale, + FLAGS.batch_size) > 0.05) + data_corrupt[replay_mask] = replay_batch[replay_mask] + + if FLAGS.pcd: + if x_mod is not None: + data_corrupt = x_mod + + feed_dict = {X_NOISE: data_corrupt, X: data, Y: label} + + if FLAGS.cclass: + feed_dict[LABEL] = label + feed_dict[LABEL_POS] = label_init + + if itr % FLAGS.log_interval == 0: + _, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent, * \ + grads = sess.run(log_output, feed_dict) + + kvs = {} + kvs['e_pos'] = e_pos.mean() + kvs['e_pos_std'] = e_pos.std() + kvs['e_neg'] = e_neg.mean() + kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg'] + kvs['e_neg_std'] = e_neg.std() + kvs['temp'] = temp + kvs['loss_e'] = loss_e.mean() + kvs['eps'] = eps.mean() + kvs['label_ent'] = label_ent + kvs['loss_ml'] = loss_ml.mean() + kvs['loss_total'] = loss_total.mean() + kvs['x_grad'] = np.abs(x_grad).mean() + kvs['x_grad_first'] = np.abs(x_grad_first).mean() + kvs['x_off'] = x_off.mean() + kvs['iter'] = itr + kvs['gamma'] = gamma + + for v, k in zip(grads, [v.name for v in gvs_dict.values()]): + kvs[k] = np.abs(v).max() + + string = "Obtained a total of " + for key, value in kvs.items(): + string += "{}: {}, ".format(key, value) + + if hvd.rank() == 0: + print(string) + logger.writekvs(kvs) + else: + _, x_mod = sess.run(output, feed_dict) + + if itr % FLAGS.save_interval == 0 and hvd.rank() == 0: + saver.save( + sess, + osp.join( + FLAGS.logdir, + FLAGS.exp, + 'model_{}'.format(itr))) + + if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != '2d': + try_im = x_mod + orig_im = data_corrupt.squeeze() + actual_im = rescale_im(data) + + orig_im = rescale_im(orig_im) + try_im = rescale_im(try_im).squeeze() + + for i, (im, t_im, actual_im_i) in enumerate( + zip(orig_im[:20], try_im[:20], actual_im)): + shape = orig_im.shape[1:] + new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:])) + size = shape[1] + new_im[:, :size] = im + new_im[:, size:2 * size] = t_im + new_im[:, 2 * size:] = actual_im_i + + log_image( + new_im, logger, 'train_gen_{}'.format(itr), step=i) + + test_im = x_mod + + try: + data_corrupt, data, label = next(dataloader_iterator) + except BaseException: + dataloader_iterator = iter(dataloader) + data_corrupt, data, label = next(dataloader_iterator) + + data_corrupt = data_corrupt.numpy() + + if FLAGS.replay_batch and ( + x_mod is not None) and len(replay_buffer) > 0: + replay_batch = replay_buffer.sample(FLAGS.batch_size) + replay_batch = decompress_x_mod(replay_batch) + replay_mask = ( + np.random.uniform( + 0, 1, (FLAGS.batch_size)) > 0.05) + data_corrupt[replay_mask] = replay_batch[replay_mask] + + if FLAGS.dataset == 'cifar10' or FLAGS.dataset == 'imagenet' or FLAGS.dataset == 'imagenetfull': + n = 128 + + if FLAGS.dataset == "imagenetfull": + n = 32 + + if len(replay_buffer) > n: + data_corrupt = decompress_x_mod(replay_buffer.sample(n)) + elif FLAGS.dataset == 'imagenetfull': + data_corrupt = np.random.uniform( + 0, FLAGS.rescale, (n, 128, 128, 3)) + else: + data_corrupt = np.random.uniform( + 0, FLAGS.rescale, (n, 32, 32, 3)) + + if FLAGS.dataset == 'cifar10': + label = np.eye(10)[np.random.randint(0, 10, (n))] + else: + label = np.eye(1000)[ + np.random.randint( + 0, 1000, (n))] + + feed_dict[X_NOISE] = data_corrupt + + feed_dict[X] = data + + if FLAGS.cclass: + feed_dict[LABEL] = label + + test_x_mod = sess.run(val_output, feed_dict) + + try_im = test_x_mod + orig_im = data_corrupt.squeeze() + actual_im = rescale_im(data.numpy()) + + orig_im = rescale_im(orig_im) + try_im = rescale_im(try_im).squeeze() + + for i, (im, t_im, actual_im_i) in enumerate( + zip(orig_im[:20], try_im[:20], actual_im)): + + shape = orig_im.shape[1:] + new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:])) + size = shape[1] + new_im[:, :size] = im + new_im[:, size:2 * size] = t_im + new_im[:, 2 * size:] = actual_im_i + log_image( + new_im, logger, 'val_gen_{}'.format(itr), step=i) + + score, std = get_inception_score(list(try_im), splits=1) + print( + "Inception score of {} with std of {}".format( + score, std)) + kvs = {} + kvs['inception_score'] = score + kvs['inception_score_std'] = std + logger.writekvs(kvs) + + if score > best_inception: + best_inception = score + saver.save( + sess, + osp.join( + FLAGS.logdir, + FLAGS.exp, + 'model_best')) + + if itr > 60000 and FLAGS.dataset == "mnist": + assert False + itr += 1 + + saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr))) + + +cifar10_map = {0: 'airplane', + 1: 'automobile', + 2: 'bird', + 3: 'cat', + 4: 'deer', + 5: 'dog', + 6: 'frog', + 7: 'horse', + 8: 'ship', + 9: 'truck'} + + +def test(target_vars, saver, sess, logger, dataloader): + X_NOISE = target_vars['X_NOISE'] + X = target_vars['X'] + Y = target_vars['Y'] + LABEL = target_vars['LABEL'] + energy_start = target_vars['energy_start'] + x_mod = target_vars['x_mod'] + x_mod = target_vars['test_x_mod'] + energy_neg = target_vars['energy_neg'] + + np.random.seed(1) + random.seed(1) + + output = [x_mod, energy_start, energy_neg] + + dataloader_iterator = iter(dataloader) + data_corrupt, data, label = next(dataloader_iterator) + data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy() + + orig_im = try_im = data_corrupt + + if FLAGS.cclass: + try_im, energy_orig, energy = sess.run( + output, {X_NOISE: orig_im, Y: label[0:1], LABEL: label}) + else: + try_im, energy_orig, energy = sess.run( + output, {X_NOISE: orig_im, Y: label[0:1]}) + + orig_im = rescale_im(orig_im) + try_im = rescale_im(try_im) + actual_im = rescale_im(data) + + for i, (im, energy_i, t_im, energy, label_i, actual_im_i) in enumerate( + zip(orig_im, energy_orig, try_im, energy, label, actual_im)): + label_i = np.array(label_i) + + shape = im.shape[1:] + new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:])) + size = shape[1] + new_im[:, :size] = im + new_im[:, size:2 * size] = t_im + + if FLAGS.cclass: + label_i = np.where(label_i == 1)[0][0] + if FLAGS.dataset == 'cifar10': + log_image(new_im, logger, '{}_{:.4f}_now_{:.4f}_{}'.format( + i, energy_i[0], energy[0], cifar10_map[label_i]), step=i) + else: + log_image( + new_im, + logger, + '{}_{:.4f}_now_{:.4f}_{}'.format( + i, + energy_i[0], + energy[0], + label_i), + step=i) + else: + log_image( + new_im, + logger, + '{}_{:.4f}_now_{:.4f}'.format( + i, + energy_i[0], + energy[0]), + step=i) + + test_ims = list(try_im) + real_ims = list(actual_im) + + for i in tqdm(range(50000 // FLAGS.batch_size + 1)): + try: + data_corrupt, data, label = dataloader_iterator.next() + except BaseException: + dataloader_iterator = iter(dataloader) + data_corrupt, data, label = dataloader_iterator.next() + + data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy() + + if FLAGS.cclass: + try_im, energy_orig, energy = sess.run( + output, {X_NOISE: data_corrupt, Y: label[0:1], LABEL: label}) + else: + try_im, energy_orig, energy = sess.run( + output, {X_NOISE: data_corrupt, Y: label[0:1]}) + + try_im = rescale_im(try_im) + real_im = rescale_im(data) + + test_ims.extend(list(try_im)) + real_ims.extend(list(real_im)) + + score, std = get_inception_score(test_ims) + print("Inception score of {} with std of {}".format(score, std)) + + +def main(): + print("Local rank: ", hvd.local_rank(), hvd.size()) + + logdir = osp.join(FLAGS.logdir, FLAGS.exp) + if hvd.rank() == 0: + if not osp.exists(logdir): + os.makedirs(logdir) + logger = TensorBoardOutputFormat(logdir) + else: + logger = None + + LABEL = None + print("Loading data...") + if FLAGS.dataset == 'cifar10': + dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale) + test_dataset = Cifar10(train=False, rescale=FLAGS.rescale) + channel_num = 3 + + X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) + X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) + LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32) + + if FLAGS.large_model: + model = ResNet32Large( + num_channels=channel_num, + num_filters=128, + train=True) + elif FLAGS.larger_model: + model = ResNet32Larger( + num_channels=channel_num, + num_filters=128) + elif FLAGS.wider_model: + model = ResNet32Wider( + num_channels=channel_num, + num_filters=192) + else: + model = ResNet32( + num_channels=channel_num, + num_filters=128) + + elif FLAGS.dataset == 'imagenet': + dataset = Imagenet(train=True) + test_dataset = Imagenet(train=False) + channel_num = 3 + X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) + X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) + LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32) + + model = ResNet32Wider( + num_channels=channel_num, + num_filters=256) + + elif FLAGS.dataset == 'imagenetfull': + channel_num = 3 + X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32) + X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32) + LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32) + + model = ResNet128( + num_channels=channel_num, + num_filters=64) + + elif FLAGS.dataset == 'mnist': + dataset = Mnist(rescale=FLAGS.rescale) + test_dataset = dataset + channel_num = 1 + X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32) + X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32) + LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32) + + model = MnistNet( + num_channels=channel_num, + num_filters=FLAGS.num_filters) + + elif FLAGS.dataset == 'dsprites': + dataset = DSprites( + cond_shape=FLAGS.cond_shape, + cond_size=FLAGS.cond_size, + cond_pos=FLAGS.cond_pos, + cond_rot=FLAGS.cond_rot) + test_dataset = dataset + channel_num = 1 + + X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) + X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) + + if FLAGS.dpos_only: + LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) + elif FLAGS.dsize_only: + LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32) + elif FLAGS.drot_only: + LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) + elif FLAGS.cond_size: + LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32) + elif FLAGS.cond_shape: + LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32) + elif FLAGS.cond_pos: + LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) + elif FLAGS.cond_rot: + LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) + else: + LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32) + LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32) + + model = DspritesNet( + num_channels=channel_num, + num_filters=FLAGS.num_filters, + cond_size=FLAGS.cond_size, + cond_shape=FLAGS.cond_shape, + cond_pos=FLAGS.cond_pos, + cond_rot=FLAGS.cond_rot) + + print("Done loading...") + + if FLAGS.dataset == "imagenetfull": + # In the case of full imagenet, use custom_tensorflow dataloader + data_loader = TFImagenetLoader('train', FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale) + else: + data_loader = DataLoader( + dataset, + batch_size=FLAGS.batch_size, + num_workers=FLAGS.data_workers, + drop_last=True, + shuffle=True) + + batch_size = FLAGS.batch_size + + weights = [model.construct_weights('context_0')] + + Y = tf.placeholder(shape=(None), dtype=tf.int32) + + # Varibles to run in training + X_SPLIT = tf.split(X, FLAGS.num_gpus) + X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus) + LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus) + LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus) + LABEL_SPLIT_INIT = list(LABEL_SPLIT) + tower_grads = [] + tower_gen_grads = [] + x_mod_list = [] + + optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999) + optimizer = hvd.DistributedOptimizer(optimizer) + + for j in range(FLAGS.num_gpus): + + if FLAGS.model_cclass: + ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus + label_tensor = tf.Variable( + tf.convert_to_tensor( + np.reshape( + np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)), + (FLAGS.batch_size * 10, 10)), + dtype=tf.float32), + trainable=False, + dtype=tf.float32) + x_split = tf.tile( + tf.reshape( + X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1)) + x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3)) + energy_pos = model.forward( + x_split, + weights[0], + label=label_tensor, + stop_at_grad=False) + + energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10)) + energy_partition_est = tf.reduce_logsumexp( + energy_pos_full, axis=1, keepdims=True) + uniform = tf.random_uniform(tf.shape(energy_pos_full)) + label_tensor = tf.argmax(-energy_pos_full - + tf.log(-tf.log(uniform)) - energy_partition_est, axis=1) + label = tf.one_hot(label_tensor, 10, dtype=tf.float32) + label = tf.Print(label, [label_tensor, energy_pos_full]) + LABEL_SPLIT[j] = label + energy_pos = tf.concat(energy_pos, axis=0) + else: + energy_pos = [ + model.forward( + X_SPLIT[j], + weights[0], + label=LABEL_POS_SPLIT[j], + stop_at_grad=False)] + energy_pos = tf.concat(energy_pos, axis=0) + + print("Building graph...") + x_mod = x_orig = X_NOISE_SPLIT[j] + + x_grads = [] + + energy_negs = [] + loss_energys = [] + + energy_negs.extend([model.forward(tf.stop_gradient( + x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)]) + eps_begin = tf.zeros(1) + + steps = tf.constant(0) + c = lambda i, x: tf.less(i, FLAGS.num_steps) + + def langevin_step(counter, x_mod): + x_mod = x_mod + tf.random_normal(tf.shape(x_mod), + mean=0.0, + stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale) + + energy_noise = energy_start = tf.concat( + [model.forward( + x_mod, + weights[0], + label=LABEL_SPLIT[j], + reuse=True, + stop_at_grad=False, + stop_batch=True)], + axis=0) + + x_grad, label_grad = tf.gradients( + FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]]) + energy_noise_old = energy_noise + + lr = FLAGS.step_lr + + if FLAGS.proj_norm != 0.0: + if FLAGS.proj_norm_type == 'l2': + x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm) + elif FLAGS.proj_norm_type == 'li': + x_grad = tf.clip_by_value( + x_grad, -FLAGS.proj_norm, FLAGS.proj_norm) + else: + print("Other types of projection are not supported!!!") + assert False + + # Clip gradient norm for now + if FLAGS.hmc: + # Step size should be tuned to get around 65% acceptance + def energy(x): + return FLAGS.temperature * \ + model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True) + + x_last = hmc(x_mod, 15., 10, energy) + else: + x_last = x_mod - (lr) * x_grad + + x_mod = x_last + x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale) + + counter = counter + 1 + + return counter, x_mod + + steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod)) + + energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j], + stop_at_grad=False, reuse=True) + x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0] + x_grads.append(x_grad) + + energy_negs.append( + model.forward( + tf.stop_gradient(x_mod), + weights[0], + label=LABEL_SPLIT[j], + stop_at_grad=False, + reuse=True)) + + test_x_mod = x_mod + + temp = FLAGS.temperature + + energy_neg = energy_negs[-1] + x_off = tf.reduce_mean( + tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j])) + + loss_energy = model.forward( + x_mod, + weights[0], + reuse=True, + label=LABEL, + stop_grad=True) + + print("Finished processing loop construction ...") + + target_vars = {} + + if FLAGS.cclass or FLAGS.model_cclass: + label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0) + label_prob = label_sum / tf.reduce_sum(label_sum) + label_ent = -tf.reduce_sum(label_prob * + tf.math.log(label_prob + 1e-7)) + else: + label_ent = tf.zeros(1) + + target_vars['label_ent'] = label_ent + + if FLAGS.train: + + if FLAGS.objective == 'logsumexp': + pos_term = temp * energy_pos + energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg)) + coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced)) + norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4 + pos_loss = tf.reduce_mean(temp * energy_pos) + neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant + loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss)) + elif FLAGS.objective == 'cd': + pos_loss = tf.reduce_mean(temp * energy_pos) + neg_loss = -tf.reduce_mean(temp * energy_neg) + loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss)) + elif FLAGS.objective == 'softplus': + loss_ml = FLAGS.ml_coeff * \ + tf.nn.softplus(temp * (energy_pos - energy_neg)) + + loss_total = tf.reduce_mean(loss_ml) + + if not FLAGS.zero_kl: + loss_total = loss_total + tf.reduce_mean(loss_energy) + + loss_total = loss_total + \ + FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg)))) + + print("Started gradient computation...") + gvs = optimizer.compute_gradients(loss_total) + gvs = [(k, v) for (k, v) in gvs if k is not None] + + print("Applying gradients...") + + tower_grads.append(gvs) + + print("Finished applying gradients.") + + target_vars['loss_ml'] = loss_ml + target_vars['total_loss'] = loss_total + target_vars['loss_energy'] = loss_energy + target_vars['weights'] = weights + target_vars['gvs'] = gvs + + target_vars['X'] = X + target_vars['Y'] = Y + target_vars['LABEL'] = LABEL + target_vars['LABEL_POS'] = LABEL_POS + target_vars['X_NOISE'] = X_NOISE + target_vars['energy_pos'] = energy_pos + target_vars['energy_start'] = energy_negs[0] + + if len(x_grads) >= 1: + target_vars['x_grad'] = x_grads[-1] + target_vars['x_grad_first'] = x_grads[0] + else: + target_vars['x_grad'] = tf.zeros(1) + target_vars['x_grad_first'] = tf.zeros(1) + + target_vars['x_mod'] = x_mod + target_vars['x_off'] = x_off + target_vars['temp'] = temp + target_vars['energy_neg'] = energy_neg + target_vars['test_x_mod'] = test_x_mod + target_vars['eps_begin'] = eps_begin + + if FLAGS.train: + grads = average_gradients(tower_grads) + train_op = optimizer.apply_gradients(grads) + target_vars['train_op'] = train_op + + config = tf.ConfigProto() + + if hvd.size() > 1: + config.gpu_options.visible_device_list = str(hvd.local_rank()) + + sess = tf.Session(config=config) + + saver = loader = tf.train.Saver( + max_to_keep=30, keep_checkpoint_every_n_hours=6) + + total_parameters = 0 + for variable in tf.trainable_variables(): + # shape is an array of tf.Dimension + shape = variable.get_shape() + variable_parameters = 1 + for dim in shape: + variable_parameters *= dim.value + total_parameters += variable_parameters + print("Model has a total of {} parameters".format(total_parameters)) + + sess.run(tf.global_variables_initializer()) + + resume_itr = 0 + + if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0: + model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter)) + resume_itr = FLAGS.resume_iter + # saver.restore(sess, model_file) + optimistic_restore(sess, model_file) + + sess.run(hvd.broadcast_global_variables(0)) + print("Initializing variables...") + + print("Start broadcast") + print("End broadcast") + + if FLAGS.train: + train(target_vars, saver, sess, + logger, data_loader, resume_itr, + logdir) + + test(target_vars, saver, sess, logger, data_loader) + + +if __name__ == "__main__": + main() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..c53c73f --- /dev/null +++ b/utils.py @@ -0,0 +1,1107 @@ +""" Utility functions. """ +import numpy as np +import os +import random +import tensorflow as tf +import warnings + +from tensorflow.contrib.layers.python import layers as tf_layers +from tensorflow.python.platform import flags +from tensorflow.contrib.framework import sort + +FLAGS = flags.FLAGS +flags.DEFINE_integer('spec_iter', 1, 'Number of iterations to normalize spectrum of matrix') +flags.DEFINE_float('spec_norm_val', 1.0, 'Desired norm of matrices') +flags.DEFINE_bool('downsample', False, 'Wheter to do average pool downsampling') +flags.DEFINE_bool('spec_eval', False, 'Set to true to prevent spectral updates') + + +def get_median(v): + v = tf.reshape(v, [-1]) + m = tf.shape(v)[0] // 2 + return tf.nn.top_k(v, m)[m - 1] + + +def set_seed(seed): + import torch + import numpy + import random + + torch.manual_seed(seed) + numpy.random.seed(seed) + random.seed(seed) + tf.set_random_seed(seed) + + +def swish(inp): + return inp * tf.nn.sigmoid(inp) + + +class ReplayBuffer(object): + def __init__(self, size): + """Create Replay buffer. + Parameters + ---------- + size: int + Max number of transitions to store in the buffer. When the buffer + overflows the old memories are dropped. + """ + self._storage = [] + self._maxsize = size + self._next_idx = 0 + + def __len__(self): + return len(self._storage) + + def add(self, ims): + batch_size = ims.shape[0] + if self._next_idx >= len(self._storage): + self._storage.extend(list(ims)) + else: + if batch_size + self._next_idx < self._maxsize: + self._storage[self._next_idx:self._next_idx + + batch_size] = list(ims) + else: + split_idx = self._maxsize - self._next_idx + self._storage[self._next_idx:] = list(ims)[:split_idx] + self._storage[:batch_size - split_idx] = list(ims)[split_idx:] + self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize + + def _encode_sample(self, idxes): + ims = [] + for i in idxes: + ims.append(self._storage[i]) + return np.array(ims) + + def sample(self, batch_size): + """Sample a batch of experiences. + Parameters + ---------- + batch_size: int + How many transitions to sample. + Returns + ------- + obs_batch: np.array + batch of observations + act_batch: np.array + batch of actions executed given obs_batch + rew_batch: np.array + rewards received as results of executing act_batch + next_obs_batch: np.array + next set of observations seen after executing act_batch + done_mask: np.array + done_mask[i] = 1 if executing act_batch[i] resulted in + the end of an episode and 0 otherwise. + """ + idxes = [random.randint(0, len(self._storage) - 1) + for _ in range(batch_size)] + return self._encode_sample(idxes) + + +def get_weight( + name, + shape, + gain=np.sqrt(2), + use_wscale=False, + fan_in=None, + spec_norm=False, + zero=False, + fc=False): + if fan_in is None: + fan_in = np.prod(shape[:-1]) + std = gain / np.sqrt(fan_in) # He init + if use_wscale: + wscale = tf.constant(np.float32(std), name=name + 'wscale') + var = tf.get_variable( + name + 'weight', + shape=shape, + initializer=tf.initializers.random_normal()) * wscale + elif spec_norm: + if zero: + var = tf.get_variable( + shape=shape, + name=name + 'weight', + initializer=tf.initializers.random_normal( + stddev=1e-10)) + var = spectral_normed_weight(var, name, lower_bound=True, fc=fc) + else: + var = tf.get_variable( + name + 'weight', + shape=shape, + initializer=tf.initializers.random_normal()) + var = spectral_normed_weight(var, name, fc=fc) + else: + if zero: + var = tf.get_variable( + name + 'weight', + shape=shape, + initializer=tf.initializers.zero()) + else: + var = tf.get_variable( + name + 'weight', + shape=shape, + initializer=tf.contrib.layers.xavier_initializer( + dtype=tf.float32)) + + return var + + +def pixel_norm(x, epsilon=1e-8): + with tf.variable_scope('PixelNorm'): + return x * tf.rsqrt(tf.reduce_mean(tf.square(x), + axis=[1, 2], keepdims=True) + epsilon) + + +# helper +def get_images(paths, labels, nb_samples=None, shuffle=True): + if nb_samples is not None: + def sampler(x): return random.sample(x, nb_samples) + else: + def sampler(x): return x + images = [(i, os.path.join(path, image)) + for i, path in zip(labels, paths) + for image in sampler(os.listdir(path))] + if shuffle: + random.shuffle(images) + return images + + +def optimistic_restore(session, save_file, v_prefix=None): + reader = tf.train.NewCheckpointReader(save_file) + saved_shapes = reader.get_variable_to_shape_map() + + var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.get_collection( + tf.GraphKeys.GLOBAL_VARIABLES) if var.name.split(':')[0] in saved_shapes]) + restore_vars = [] + with tf.variable_scope('', reuse=True): + for var_name, saved_var_name in var_names: + try: + curr_var = tf.get_variable(saved_var_name) + except Exception as e: + print(e) + continue + var_shape = curr_var.get_shape().as_list() + if var_shape == saved_shapes[saved_var_name]: + restore_vars.append(curr_var) + else: + print(var_name) + print(var_shape, saved_shapes[saved_var_name]) + + saver = tf.train.Saver(restore_vars) + saver.restore(session, save_file) + + +def optimistic_remap_restore(session, save_file, v_prefix): + reader = tf.train.NewCheckpointReader(save_file) + saved_shapes = reader.get_variable_to_shape_map() + + vars_list = tf.get_collection( + tf.GraphKeys.GLOBAL_VARIABLES, + scope='context_{}'.format(v_prefix)) + var_names = sorted([(var.name.split(':')[0], var) for var in vars_list if ( + (var.name.split(':')[0]).replace('context_{}'.format(v_prefix), 'context_0') in saved_shapes)]) + restore_vars = [] + + v_map = {} + with tf.variable_scope('', reuse=True): + for saved_var_name, curr_var in var_names: + var_shape = curr_var.get_shape().as_list() + saved_var_name = saved_var_name.replace( + 'context_{}'.format(v_prefix), 'context_0') + if var_shape == saved_shapes[saved_var_name]: + v_map[saved_var_name] = curr_var + else: + print(saved_var_name) + print(var_shape, saved_shapes[saved_var_name]) + + saver = tf.train.Saver(v_map) + saver.restore(session, save_file) + + +def remap_restore(session, save_file, i): + reader = tf.train.NewCheckpointReader(save_file) + saved_shapes = reader.get_variable_to_shape_map() + var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.global_variables() + if var.name.split(':')[0] in saved_shapes]) + restore_vars = [] + with tf.variable_scope('', reuse=True): + for var_name, saved_var_name in var_names: + try: + curr_var = tf.get_variable(saved_var_name) + except Exception as e: + print(e) + continue + var_shape = curr_var.get_shape().as_list() + if var_shape == saved_shapes[saved_var_name]: + restore_vars.append(curr_var) + print(restore_vars) + saver = tf.train.Saver(restore_vars) + saver.restore(session, save_file) + + +# Network weight initializers +def init_conv_weight( + weights, + scope, + k, + c_in, + c_out, + spec_norm=True, + zero=False, + scale=1.0, + classes=1): + + if spec_norm: + spec_norm = FLAGS.spec_norm + + conv_weights = {} + with tf.variable_scope(scope): + if zero: + conv_weights['c'] = get_weight( + 'c', [k, k, c_in, c_out], spec_norm=spec_norm, zero=True) + else: + conv_weights['c'] = get_weight( + 'c', [k, k, c_in, c_out], spec_norm=spec_norm) + + conv_weights['b'] = tf.get_variable( + shape=[c_out], name='b', initializer=tf.initializers.zeros()) + + if classes != 1: + conv_weights['g'] = tf.get_variable( + shape=[ + classes, + c_out], + name='g', + initializer=tf.initializers.ones()) + conv_weights['gb'] = tf.get_variable( + shape=[ + classes, + c_in], + name='gb', + initializer=tf.initializers.zeros()) + else: + conv_weights['g'] = tf.get_variable( + shape=[c_out], name='g', initializer=tf.initializers.ones()) + conv_weights['gb'] = tf.get_variable( + shape=[c_in], name='gb', initializer=tf.initializers.zeros()) + + conv_weights['cb'] = tf.get_variable( + shape=[c_in], name='cb', initializer=tf.initializers.zeros()) + + weights[scope] = conv_weights + + +def init_convt_weight( + weights, + scope, + k, + c_in, + c_out, + spec_norm=True, + zero=False, + scale=1.0, + classes=1): + + if spec_norm: + spec_norm = FLAGS.spec_norm + + conv_weights = {} + with tf.variable_scope(scope): + if zero: + conv_weights['c'] = get_weight( + 'c', [k, k, c_in, c_out], spec_norm=spec_norm, zero=True) + else: + conv_weights['c'] = get_weight( + 'c', [k, k, c_in, c_out], spec_norm=spec_norm) + + conv_weights['b'] = tf.get_variable( + shape=[c_in], name='b', initializer=tf.initializers.zeros()) + + if classes != 1: + conv_weights['g'] = tf.get_variable( + shape=[ + classes, + c_in], + name='g', + initializer=tf.initializers.ones()) + conv_weights['gb'] = tf.get_variable( + shape=[ + classes, + c_out], + name='gb', + initializer=tf.initializers.zeros()) + else: + conv_weights['g'] = tf.get_variable( + shape=[c_in], name='g', initializer=tf.initializers.ones()) + conv_weights['gb'] = tf.get_variable( + shape=[c_out], name='gb', initializer=tf.initializers.zeros()) + + conv_weights['cb'] = tf.get_variable( + shape=[c_in], name='cb', initializer=tf.initializers.zeros()) + + weights[scope] = conv_weights + + +def init_attention_weight( + weights, + scope, + c_in, + k, + trainable_gamma=True, + spec_norm=True): + + if spec_norm: + spec_norm = FLAGS.spec_norm + + atten_weights = {} + with tf.variable_scope(scope): + atten_weights['q'] = get_weight( + 'atten_q', [1, 1, c_in, k], spec_norm=spec_norm) + atten_weights['q_b'] = tf.get_variable( + shape=[k], name='atten_q_b1', initializer=tf.initializers.zeros()) + atten_weights['k'] = get_weight( + 'atten_k', [1, 1, c_in, k], spec_norm=spec_norm) + atten_weights['k_b'] = tf.get_variable( + shape=[k], name='atten_k_b1', initializer=tf.initializers.zeros()) + atten_weights['v'] = get_weight( + 'atten_v', [1, 1, c_in, c_in], spec_norm=spec_norm) + atten_weights['v_b'] = tf.get_variable( + shape=[c_in], name='atten_v_b1', initializer=tf.initializers.zeros()) + atten_weights['gamma'] = tf.get_variable( + shape=[1], name='gamma', initializer=tf.initializers.zeros()) + + weights[scope] = atten_weights + + +def init_fc_weight(weights, scope, c_in, c_out, spec_norm=True): + fc_weights = {} + + if spec_norm: + spec_norm = FLAGS.spec_norm + + with tf.variable_scope(scope): + fc_weights['w'] = get_weight( + 'w', [c_in, c_out], spec_norm=spec_norm, fc=True) + fc_weights['b'] = tf.get_variable( + shape=[c_out], name='b', initializer=tf.initializers.zeros()) + + weights[scope] = fc_weights + + +def init_res_weight( + weights, + scope, + k, + c_in, + c_out, + hidden_dim=None, + spec_norm=True, + res_scale=1.0, + classes=1): + + if not hidden_dim: + hidden_dim = c_in + + if spec_norm: + spec_norm = FLAGS.spec_norm + + init_conv_weight( + weights, + scope + + '_res_c1', + k, + c_in, + c_out, + spec_norm=spec_norm, + scale=res_scale, + classes=classes) + init_conv_weight( + weights, + scope + '_res_c2', + k, + c_out, + c_out, + spec_norm=spec_norm, + zero=True, + scale=res_scale, + classes=classes) + + if c_in != c_out: + init_conv_weight( + weights, + scope + + '_res_adaptive', + k, + c_in, + c_out, + spec_norm=spec_norm, + scale=res_scale, + classes=classes) + +# Network forward helpers + + +def smart_conv_block(inp, weights, reuse, scope, use_stride=True, **kwargs): + weights = weights[scope] + return conv_block( + inp, + weights['c'], + weights['b'], + reuse, + scope, + scale=weights['g'], + bias=weights['gb'], + class_bias=weights['cb'], + use_stride=use_stride, + **kwargs) + + +def smart_convt_block( + inp, + weights, + reuse, + scope, + output_dim, + upsample=True, + label=None): + weights = weights[scope] + + cweight = weights['c'] + bweight = weights['b'] + scale = weights['g'] + bias = weights['gb'] + class_bias = weights['cb'] + + if upsample: + stride = [1, 2, 2, 1] + else: + stride = [1, 1, 1, 1] + + if label is not None: + bias_batch = tf.matmul(label, bias) + batch = tf.shape(bias_batch)[0] + dim = tf.shape(bias_batch)[1] + bias = tf.reshape(bias_batch, (batch, 1, 1, dim)) + + inp = inp + bias + + shape = cweight.get_shape() + conv_output = tf.nn.conv2d_transpose(inp, + cweight, + [tf.shape(inp)[0], + output_dim, + output_dim, + cweight.get_shape().as_list()[-2]], + stride, + 'SAME') + + if label is not None: + scale_batch = tf.matmul(label, scale) + class_bias + batch = tf.shape(scale_batch)[0] + dim = tf.shape(scale_batch)[1] + scale = tf.reshape(scale_batch, (batch, 1, 1, dim)) + + conv_output = conv_output * scale + + conv_output = tf.nn.leaky_relu(conv_output) + + return conv_output + + +def smart_res_block( + inp, + weights, + reuse, + scope, + downsample=True, + adaptive=True, + stop_batch=False, + upsample=False, + label=None, + act=tf.nn.leaky_relu, + dropout=False, + train=False, + **kwargs): + gn1 = weights[scope + '_res_c1'] + gn2 = weights[scope + '_res_c2'] + c1 = smart_conv_block( + inp, + weights, + reuse, + scope + '_res_c1', + use_stride=False, + activation=None, + extra_bias=True, + label=label, + **kwargs) + + if dropout: + c1 = tf.layers.dropout(c1, rate=0.5, training=train) + + c1 = act(c1) + c2 = smart_conv_block( + c1, + weights, + reuse, + scope + '_res_c2', + use_stride=False, + activation=None, + use_scale=True, + extra_bias=True, + label=label, + **kwargs) + + if adaptive: + c_bypass = smart_conv_block( + inp, + weights, + reuse, + scope + + '_res_adaptive', + use_stride=False, + activation=None, + **kwargs) + else: + c_bypass = inp + + res = c2 + c_bypass + + if upsample: + res_shape = tf.shape(res) + res_shape_list = res.get_shape() + res = tf.image.resize_nearest_neighbor( + res, [2 * res_shape_list[1], 2 * res_shape_list[2]]) + elif downsample: + res = tf.nn.avg_pool(res, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID') + + res = act(res) + + return res + + +def smart_res_block_optim(inp, weights, reuse, scope, **kwargs): + c1 = smart_conv_block( + inp, + weights, + reuse, + scope + '_res_c1', + use_stride=False, + activation=None, + **kwargs) + c1 = tf.nn.leaky_relu(c1) + c2 = smart_conv_block( + c1, + weights, + reuse, + scope + '_res_c2', + use_stride=False, + activation=None, + **kwargs) + + inp = tf.nn.avg_pool(inp, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID') + c_bypass = smart_conv_block( + inp, + weights, + reuse, + scope + + '_res_adaptive', + use_stride=False, + activation=None, + **kwargs) + c2 = tf.nn.avg_pool(c2, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID') + + res = c2 + c_bypass + + return c2 + + +def groupsort(k=4): + def sortact(inp): + old_shape = tf.shape(inp) + inp = sort(tf.reshape(inp, (-1, 4))) + inp = tf.reshape(inp, old_shape) + return inp + return sortact + + +def smart_atten_block(inp, weights, reuse, scope, **kwargs): + w = weights[scope] + return attention( + inp, + w['q'], + w['q_b'], + w['k'], + w['k_b'], + w['v'], + w['v_b'], + w['gamma'], + reuse, + scope, + **kwargs) + + +def smart_fc_block(inp, weights, reuse, scope, use_bias=True): + weights = weights[scope] + output = tf.matmul(inp, weights['w']) + + if use_bias: + output = output + weights['b'] + + return output + + +# Network helpers +def conv_block( + inp, + cweight, + bweight, + reuse, + scope, + use_stride=True, + activation=tf.nn.leaky_relu, + pn=False, + bn=False, + gn=False, + ln=False, + scale=None, + bias=None, + class_bias=None, + use_bias=False, + downsample=False, + stop_batch=False, + use_scale=False, + extra_bias=False, + average=False, + label=None): + """ Perform, conv, batch norm, nonlinearity, and max pool """ + stride, no_stride = [1, 2, 2, 1], [1, 1, 1, 1] + _, h, w, _ = inp.get_shape() + + if FLAGS.downsample: + stride = no_stride + + if not use_bias: + bweight = 0 + + if extra_bias: + if label is not None: + if len(bias.get_shape()) == 1: + bias = tf.reshape(bias, (1, -1)) + bias_batch = tf.matmul(label, bias) + batch = tf.shape(bias_batch)[0] + dim = tf.shape(bias_batch)[1] + bias = tf.reshape(bias_batch, (batch, 1, 1, dim)) + + inp = inp + bias + + if not use_stride: + conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') + else: + conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME') + + if use_scale: + if label is not None: + if len(scale.get_shape()) == 1: + scale = tf.reshape(scale, (1, -1)) + scale_batch = tf.matmul(label, scale) + class_bias + batch = tf.shape(scale_batch)[0] + dim = tf.shape(scale_batch)[1] + scale = tf.reshape(scale_batch, (batch, 1, 1, dim)) + + conv_output = conv_output * scale + + if use_bias: + conv_output = conv_output + bweight + + if activation is not None: + conv_output = activation(conv_output) + + if bn: + conv_output = batch_norm(conv_output, scale, bias) + if pn: + conv_output = pixel_norm(conv_output) + if gn: + conv_output = group_norm( + conv_output, scale, bias, stop_batch=stop_batch) + if ln: + conv_output = layer_norm(conv_output, scale, bias) + + if FLAGS.downsample and use_stride: + conv_output = tf.layers.average_pooling2d(conv_output, (2, 2), 2) + + return conv_output + + +def conv_block_1d( + inp, + cweight, + bweight, + reuse, + scope, + activation=tf.nn.leaky_relu): + """ Perform, conv, batch norm, nonlinearity, and max pool """ + stride = 1 + + conv_output = tf.nn.conv1d(inp, cweight, stride, 'SAME') + bweight + + if activation is not None: + conv_output = activation(conv_output) + + return conv_output + + +def conv_block_3d( + inp, + cweight, + bweight, + reuse, + scope, + use_stride=True, + activation=tf.nn.leaky_relu, + pn=False, + bn=False, + gn=False, + ln=False, + scale=None, + bias=None, + use_bias=False): + """ Perform, conv, batch norm, nonlinearity, and max pool """ + stride, no_stride = [1, 1, 2, 2, 1], [1, 1, 1, 1, 1] + _, d, h, w, _ = inp.get_shape() + + if not use_bias: + bweight = 0 + + if not use_stride: + conv_output = tf.nn.conv3d(inp, cweight, no_stride, 'SAME') + bweight + else: + conv_output = tf.nn.conv3d(inp, cweight, stride, 'SAME') + bweight + + if activation is not None: + conv_output = activation(conv_output, alpha=0.1) + + if bn: + conv_output = batch_norm(conv_output, scale, bias) + if pn: + conv_output = pixel_norm(conv_output) + if gn: + conv_output = group_norm(conv_output, scale, bias) + if ln: + conv_output = layer_norm(conv_output, scale, bias) + + if FLAGS.downsample and use_stride: + conv_output = tf.layers.average_pooling2d(conv_output, (2, 2), 2) + + return conv_output + + +def group_norm(inp, scale, bias, g=32, eps=1e-6, stop_batch=False): + """Applies group normalization assuming nhwc format""" + n, h, w, c = inp.shape + inp = tf.reshape(inp, (tf.shape(inp)[0], h, w, c // g, g)) + + mean, var = tf.nn.moments(inp, [1, 2, 4], keep_dims=True) + gain = tf.rsqrt(var + eps) + + # if stop_batch: + # gain = tf.stop_gradient(gain) + + output = gain * (inp - mean) + output = tf.reshape(output, (tf.shape(inp)[0], h, w, c)) + + if scale is not None: + output = output * scale + + if bias is not None: + output = output + bias + + return output + + +def layer_norm(inp, scale, bias, eps=1e-6): + """Applies group normalization assuming nhwc format""" + n, h, w, c = inp.shape + + mean, var = tf.nn.moments(inp, [1, 2, 3], keep_dims=True) + gain = tf.rsqrt(var + eps) + output = gain * (inp - mean) + + if scale is not None: + output = output * scale + + if bias is not None: + output = output + bias + + return output + + +def conv_cond_concat(x, y): + """Concatenate conditioning vector on feature map axis.""" + x_shapes = tf.shape(x) + y_shapes = tf.shape(y) + + return tf.concat( + [x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]]) / 10.], 3) + + +def attention( + inp, + q, + q_b, + k, + k_b, + v, + v_b, + gamma, + reuse, + scope, + stop_at_grad=False, + seperate=False, + scale=False, + train=False, + dropout=0.0): + conv_q = conv_block( + inp, + q, + q_b, + reuse=reuse, + scope=scope, + use_stride=False, + activation=None, + use_bias=True, + pn=False, + bn=False, + gn=False) + conv_k = conv_block( + inp, + k, + k_b, + reuse=reuse, + scope=scope, + use_stride=False, + activation=None, + use_bias=True, + pn=False, + bn=False, + gn=False) + + conv_v = conv_block( + inp, + v, + v_b, + reuse=reuse, + scope=scope, + use_stride=False, + pn=False, + bn=False, + gn=False) + + c_num = float(conv_q.get_shape().as_list()[-1]) + s = tf.matmul(hw_flatten(conv_q), hw_flatten(conv_k), transpose_b=True) + + if scale: + s = s / (c_num) ** 0.5 + + if train: + s = tf.nn.dropout(s, 0.9) + + beta = tf.nn.softmax(s, axis=-1) + o = tf.matmul(beta, hw_flatten(conv_v)) + o = tf.reshape(o, shape=tf.shape(inp)) + inp = inp + gamma * o + + if not seperate: + return inp + else: + return gamma * o + + +def attention_2d( + inp, + q, + q_b, + k, + k_b, + v, + v_b, + reuse, + scope, + stop_at_grad=False, + seperate=False, + scale=False): + inp_shape = tf.shape(inp) + inp_compact = tf.reshape( + inp, + (inp_shape[0] * + FLAGS.input_objects * + inp_shape[1], + inp.shape[3])) + f_q = tf.matmul(inp_compact, q) + q_b + f_k = tf.matmul(inp_compact, k) + k_b + f_v = tf.nn.leaky_relu(tf.matmul(inp_compact, v) + v_b) + + f_q = tf.reshape(f_q, + (inp_shape[0], + inp_shape[1], + inp_shape[2], + tf.shape(f_q)[-1])) + f_k = tf.reshape(f_k, + (inp_shape[0], + inp_shape[1], + inp_shape[2], + tf.shape(f_k)[-1])) + f_v = tf.reshape( + f_v, + (inp_shape[0], + inp_shape[1], + inp_shape[2], + inp_shape[3])) + + s = tf.matmul(f_k, f_q, transpose_b=True) + c_num = (32**0.5) + + if scale: + s = s / c_num + + beta = tf.nn.softmax(s, axis=-1) + + o = tf.reshape(tf.matmul(beta, f_v), inp_shape) + inp + + return o + + +def hw_flatten(x): + shape = tf.shape(x) + return tf.reshape(x, [tf.shape(x)[0], -1, shape[-1]]) + + +def batch_norm(inp, scale, bias, eps=0.01): + mean, var = tf.nn.moments(inp, [0]) + output = tf.nn.batch_normalization(inp, mean, var, bias, scale, eps) + return output + + +def normalize(inp, activation, reuse, scope): + if FLAGS.norm == 'batch_norm': + return tf_layers.batch_norm( + inp, + activation_fn=activation, + reuse=reuse, + scope=scope) + elif FLAGS.norm == 'layer_norm': + return tf_layers.layer_norm( + inp, + activation_fn=activation, + reuse=reuse, + scope=scope) + elif FLAGS.norm == 'None': + if activation is not None: + return activation(inp) + else: + return inp + +# Loss functions + + +def mse(pred, label): + pred = tf.reshape(pred, [-1]) + label = tf.reshape(label, [-1]) + return tf.reduce_mean(tf.square(pred - label)) + + +NO_OPS = 'NO_OPS' + + +def _l2normalize(v, eps=1e-12): + return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps) + + +def spectral_normed_weight(w, name, lower_bound=False, iteration=1, fc=False): + if fc: + iteration = 2 + + w_shape = w.shape.as_list() + w = tf.reshape(w, [-1, w_shape[-1]]) + + iteration = FLAGS.spec_iter + sigma_new = FLAGS.spec_norm_val + + u = tf.get_variable(name + "_u", + [1, + w_shape[-1]], + initializer=tf.random_normal_initializer(), + trainable=False) + + u_hat = u + v_hat = None + for i in range(iteration): + """ + power iteration + Usually iteration = 1 will be enough + """ + v_ = tf.matmul(u_hat, tf.transpose(w)) + v_hat = tf.nn.l2_normalize(v_) + + u_ = tf.matmul(v_hat, w) + u_hat = tf.nn.l2_normalize(u_) + + u_hat = tf.stop_gradient(u_hat) + v_hat = tf.stop_gradient(v_hat) + + sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) + + if FLAGS.spec_eval: + dep = [] + else: + dep = [u.assign(u_hat)] + + with tf.control_dependencies(dep): + if lower_bound: + sigma = sigma + 1e-6 + w_norm = w / sigma * tf.minimum(sigma, 1) * sigma_new + else: + w_norm = w / sigma * sigma_new + + w_norm = tf.reshape(w_norm, w_shape) + + return w_norm + + +def average_gradients(tower_grads): + """Calculate the average gradient for each shared variable across all towers. + Note that this function provides a synchronization point across all towers. + Args: + tower_grads: List of lists of (gradient, variable) tuples. The outer list + is over individual gradients. The inner list is over the gradient + calculation for each tower. + Returns: + List of pairs of (gradient, variable) where the gradient has been averaged + across all towers. + """ + average_grads = [] + for grad_and_vars in zip(*tower_grads): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) + grads = [] + for g, v in grad_and_vars: + if g is not None: + # Add 0 dimension to the gradients to represent the tower. + expanded_g = tf.expand_dims(g, 0) + + # Append on a 'tower' dimension which we will average over + # below. + grads.append(expanded_g) + else: + print(g, v) + + # Average over the 'tower' dimension. + grad = tf.concat(axis=0, values=grads) + grad = tf.reduce_mean(grad, 0) + + # Keep in mind that the Variables are redundant because they are shared + # across towers. So .. we will just return the first tower's pointer to + # the Variable. + v = grad_and_vars[0][1] + grad_and_var = (grad, v) + average_grads.append(grad_and_var) + return average_grads