Add adapted code from OpenAI ebs

This commit is contained in:
Steinkirch 2020-05-10 22:32:26 -07:00
parent 42ee40d37d
commit ec2ff8be87
17 changed files with 6795 additions and 0 deletions

5
.gitignore vendored
View File

@ -127,3 +127,8 @@ dmypy.json
# Pyre type checker
.pyre/
# Custom
sandbox_cachedir/
cachedir
results

124
README.md Normal file
View File

@ -0,0 +1,124 @@
# Implicit Generation and Generalization in Energy Based Models
Code for [Implicit Generation and Generalization in Energy Based Models](https://arxiv.org/pdf/1903.08689.pdf). Blog post can be found [here](https://openai.com/blog/energy-based-models/) and website with pretrained models can be found [here](https://sites.google.com/view/igebm/home).
## Requirements
To install the prerequisites for the project run
```
pip install -r requirements.txt
mkdir sandbox_cachedir
```
Download all [pretrained models](https://sites.google.com/view/igebm/home) and unzip into the folder cachedir.
## Download Datasets
For MNIST and CIFAR-10 datasets, the code will directly download the data.
For ImageNet 128x128 dataset, download the TFRecords of the Imagenet dataset by running the following command
```
for i in $(seq -f "%05g" 0 1023)
do
wget https://storage.googleapis.com/ebm_demo/data/imagenet/train-$i-of-01024
done
for i in $(seq -f "%05g" 0 127)
do
wget https://storage.googleapis.com/ebm_demo/data/imagenet/validation-$i-of-00128
done
wget https://storage.googleapis.com/ebm_demo/data/imagenet/index.json
```
For Imagenet 32x32 dataset, download the Imagenet 32x32 dataset and unzip by running the following command
```
wget https://storage.googleapis.com/ebm_demo/data/imagenet32/Imagenet32_train.zip
wget https://storage.googleapis.com/ebm_demo/data/imagenet32/Imagenet32_val.zip
```
For dSprites dataset, download the dataset by running
```
wget https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true
```
## Training
To train on different datasets:
For CIFAR-10 Unconditional
```
python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --large_model
```
For CIFAR-10 Conditional
```
python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --cclass
```
For ImageNet 32x32 Conditional
```
python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 step_lr=10.0 --proj_norm=0.01 --replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=<imagenet32x32 path>
```
For ImageNet 128x128 Conditional
```
python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0 --replay_batch --swish_act --cclass --zero_kl --dataset=imagenetfull --imagenet_datadir=<full imagenet path>
```
All code supports horovod execution, so model training can be increased substantially by using multiple different workers by running each command.
```
mpiexec -n <worker_num> <command>
```
## Demo
The imagenet_demo.py file contains code to experiments with EBMs on conditional ImageNet 128x128. To generate a gif on sampling, you can run the command:
```
python imagenet_demo.py --exp=imagenet128_cond --resume_iter=2238000 --swish_act
```
The ebm_sandbox.py file contains several different tasks that can be used to evaluate EBMs, which are defined by different settings of task flag in the file. For example, to visualize cross class mappings in CIFAR-10, you can run:
```
python ebm_sandbox.py --task=crossclass --num_steps=40 --exp=cifar10_cond --resume_iter=74700
```
## Generalization
To test generalization to out of distribution classification for SVHN (with similar commands for other datasets)
```
python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200 --large_model --svhnmix --cclass=False
```
To test classification on CIFAR-10 using a conditional model under either L2 or Li perturbations
```
python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd=<number of pgd steps> --num_steps=10 --lival=<li bound value> --wider_model
```
## Concept Combination
To train EBMs on conditional dSprites dataset, you can train each model seperately on each conditioned latent in cond_pos, cond_rot, cond_shape, cond_scale, with an example command given below.
```
python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act --cond_pos --replay_batch -cclass
```
Once models are trained, they can be sampled from jointly by running
```
python ebm_combine.py --task=conceptcombine --exp_size=<exp_size> --exp_shape=<exp_shape> --exp_pos=<exp_pos> --exp_rot=<exp_rot> --resume_size=<resume_size> --resume_shape=<resume_shape> --resume_rot=<resume_rot> --resume_pos=<resume_pos>
```

249
ais.py Normal file
View File

@ -0,0 +1,249 @@
import tensorflow as tf
import math
from hmc import hmc
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
from data import Cifar10, Mnist, DSprites
from scipy.misc import logsumexp
from scipy.misc import imsave
from utils import optimistic_restore
import os.path as osp
import numpy as np
from tqdm import tqdm
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss')
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
flags.DEFINE_integer('batch_size', 16, 'Size of inputs')
flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from')
flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais')
flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian')
flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box')
flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model')
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not')
flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not')
flags.DEFINE_bool('large_model', False, 'Use large model to evaluate')
flags.DEFINE_bool('wider_model', False, 'Use large model to evaluate')
flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps')
FLAGS = flags.FLAGS
label_default = np.eye(10)[0:1, :]
label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))
def unscale_im(im):
return (255 * np.clip(im, 0, 1)).astype(np.uint8)
def gauss_prob_log(x, prec=1.0):
nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec
return norm_constant_log + prob_density_log
def uniform_prob_log(x):
return tf.zeros(1)
def model_prob_log(x, e_func, weights, temp):
if FLAGS.cclass:
batch_size = tf.shape(x)[0]
label_tiled = tf.tile(label_default, (batch_size, 1))
e_raw = e_func.forward(x, weights, label=label_tiled)
else:
e_raw = e_func.forward(x, weights)
energy = tf.reduce_sum(e_raw, axis=[1])
return -temp * energy
def bridge_prob_neg_log(alpha, x, e_func, weights, temp):
if FLAGS.dataset == "gauss":
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature)
else:
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp)
# Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely
if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1])
elif FLAGS.dataset == 'mnist':
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1, 2])
else:
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis = [1, 2, 3])
return -norm_prob + oob_prob
def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10):
if FLAGS.dataset == "2d":
x = tf.placeholder(tf.float32, shape=(None, 2))
elif FLAGS.dataset == "gauss":
x = tf.placeholder(tf.float32, shape=(None, FLAGS.gauss_dim))
elif FLAGS.dataset == "mnist":
x = tf.placeholder(tf.float32, shape=(None, 28, 28))
else:
x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3))
x_init = x
alpha_prev = tf.placeholder(tf.float32, shape=())
alpha_new = tf.placeholder(tf.float32, shape=())
approx_lr = tf.placeholder(tf.float32, shape=())
chain_weights = tf.zeros(batch_size)
# for i in range(1, prop_dist+1):
# print("processing loop {}".format(i))
# alpha_prev = (i-1) / prop_dist
# alpha_new = i / prop_dist
prob_log_old_neg = bridge_prob_neg_log(alpha_prev, x, e_func, weights, temp)
prob_log_new_neg = bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)
chain_weights = -prob_log_new_neg + prob_log_old_neg
# chain_weights = tf.Print(chain_weights, [chain_weights])
# Sample new x using HMC
def unorm_prob(x):
return bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)
for j in range(1):
x = hmc(x, approx_lr, hmc_step, unorm_prob)
return chain_weights, alpha_prev, alpha_new, x, x_init, approx_lr
def main():
# Initialize dataset
if FLAGS.dataset == 'cifar10':
dataset = Cifar10(train=False, rescale=FLAGS.rescale)
channel_num = 3
dim_input = 32 * 32 * 3
elif FLAGS.dataset == 'imagenet':
dataset = ImagenetClass()
channel_num = 3
dim_input = 64 * 64 * 3
elif FLAGS.dataset == 'mnist':
dataset = Mnist(train=False, rescale=FLAGS.rescale)
channel_num = 1
dim_input = 28 * 28 * 1
elif FLAGS.dataset == 'dsprites':
dataset = DSprites()
channel_num = 1
dim_input = 64 * 64 * 1
elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
dataset = Box2D()
dim_output = 1
data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)
if FLAGS.dataset == 'mnist':
model = MnistNet(num_channels=channel_num)
elif FLAGS.dataset == 'cifar10':
if FLAGS.large_model:
model = ResNet32Large(num_filters=128)
elif FLAGS.wider_model:
model = ResNet32Wider(num_filters=192)
else:
model = ResNet32(num_channels=channel_num, num_filters=128)
elif FLAGS.dataset == 'dsprites':
model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)
weights = model.construct_weights('context_{}'.format(0))
config = tf.ConfigProto()
sess = tf.Session(config=config)
saver = loader = tf.train.Saver(max_to_keep=10)
sess.run(tf.global_variables_initializer())
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
resume_itr = FLAGS.resume_iter
if FLAGS.resume_iter != "-1":
optimistic_restore(sess, model_file)
else:
print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
# saver.restore(sess, model_file)
chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
print("Finished constructing ancestral sample ...................")
if FLAGS.dataset != "gauss":
comb_weights_cum = []
batch_size = tf.shape(x_init)[0]
label_tiled = tf.tile(label_default, (batch_size, 1))
e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
e_pos_list = []
for data_corrupt, data, label_gt in tqdm(data_loader):
e_pos = sess.run([e_compute], {x_init: data})[0]
e_pos_list.extend(list(e_pos))
print(len(e_pos_list))
print("Positive sample probability ", np.mean(e_pos_list), np.std(e_pos_list))
if FLAGS.dataset == "2d":
alr = 0.0045
elif FLAGS.dataset == "gauss":
alr = 0.0085
elif FLAGS.dataset == "mnist":
alr = 0.0065
#90 alr = 0.0035
else:
# alr = 0.0125
if FLAGS.rescale == 8:
alr = 0.0085
else:
alr = 0.0045
#
for i in range(1):
tot_weight = 0
for j in tqdm(range(1, FLAGS.pdist+1)):
if j == 1:
if FLAGS.dataset == "cifar10":
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
elif FLAGS.dataset == "gauss":
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
elif FLAGS.dataset == "mnist":
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
else:
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))
alpha_prev = (j-1) / FLAGS.pdist
alpha_new = j / FLAGS.pdist
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
tot_weight = tot_weight + cweight
print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))
tot_weight = 0
for j in tqdm(range(FLAGS.pdist, 0, -1)):
alpha_new = (j-1) / FLAGS.pdist
alpha_prev = j / FLAGS.pdist
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
tot_weight = tot_weight - cweight
print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))
if __name__ == "__main__":
main()

236
custom_adam.py Normal file
View File

@ -0,0 +1,236 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adam for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export
import tensorflow as tf
@tf_export("train.AdamOptimizer")
class AdamOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adam algorithm.
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
use_locking=False, name="Adam"):
"""Construct a new Adam optimizer.
Initialization:
$$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
$$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
$$t := 0 \text{(Initialize timestep)}$$
The update rule for `variable` with gradient `g` uses an optimization
described at the end of section2 of the paper:
$$t := t + 1$$
$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
The default value of 1e-8 for epsilon might not be a good default in
general. For example, when training an Inception network on ImageNet a
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
formulation just before Section 2.1 of the Kingma and Ba paper rather than
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
hat" in the paper.
The sparse implementation of this algorithm (used when the gradient is an
IndexedSlices object, typically because of `tf.gather` or an embedding
lookup in the forward pass) does apply momentum to variable slices even if
they were not used in the forward pass (meaning they have a gradient equal
to zero). Momentum decay (beta1) is also applied to the entire momentum
accumulator. This means that the sparse behavior is equivalent to the dense
behavior (in contrast to some momentum implementations which ignore momentum
unless a variable slice was actually used).
Args:
learning_rate: A Tensor or a floating point value. The learning rate.
beta1: A float value or a constant float tensor.
The exponential decay rate for the 1st moment estimates.
beta2: A float value or a constant float tensor.
The exponential decay rate for the 2nd moment estimates.
epsilon: A small constant for numerical stability. This epsilon is
"epsilon hat" in the Kingma and Ba paper (in the formula just before
Section 2.1), not the epsilon in Algorithm 1 of the paper.
use_locking: If True use locks for update operations.
name: Optional name for the operations created when applying gradients.
Defaults to "Adam".
@compatibility(eager)
When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
`epsilon` can each be a callable that takes no arguments and returns the
actual value to use. This can be useful for changing these values across
different invocations of optimizer functions.
@end_compatibility
"""
super(AdamOptimizer, self).__init__(use_locking, name)
self._lr = learning_rate
self._beta1 = beta1
self._beta2 = beta2
self._epsilon = epsilon
# Tensor versions of the constructor arguments, created in _prepare().
self._lr_t = None
self._beta1_t = None
self._beta2_t = None
self._epsilon_t = None
# Created in SparseApply if needed.
self._updated_lr = None
def _get_beta_accumulators(self):
with ops.init_scope():
if context.executing_eagerly():
graph = None
else:
graph = ops.get_default_graph()
return (self._get_non_slot_variable("beta1_power", graph=graph),
self._get_non_slot_variable("beta2_power", graph=graph))
def _create_slots(self, var_list):
# Create the beta1 and beta2 accumulators on the same device as the first
# variable. Sort the var_list to make sure this device is consistent across
# workers (these need to go on the same PS, otherwise some updates are
# silently ignored).
first_var = min(var_list, key=lambda x: x.name)
self._create_non_slot_variable(initial_value=self._beta1,
name="beta1_power",
colocate_with=first_var)
self._create_non_slot_variable(initial_value=self._beta2,
name="beta2_power",
colocate_with=first_var)
# Create slots for the first and second moments.
for v in var_list:
self._zeros_slot(v, "m", self._name)
self._zeros_slot(v, "v", self._name)
def _prepare(self):
lr = self._call_if_callable(self._lr)
beta1 = self._call_if_callable(self._beta1)
beta2 = self._call_if_callable(self._beta2)
epsilon = self._call_if_callable(self._epsilon)
self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
def _apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta1_power, beta2_power = self._get_beta_accumulators()
clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
# Clip gradients by 3 std
return training_ops.apply_adam(
var, m, v,
math_ops.cast(beta1_power, var.dtype.base_dtype),
math_ops.cast(beta2_power, var.dtype.base_dtype),
math_ops.cast(self._lr_t, var.dtype.base_dtype),
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
grad, use_locking=self._use_locking).op
def _resource_apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta1_power, beta2_power = self._get_beta_accumulators()
return training_ops.resource_apply_adam(
var.handle, m.handle, v.handle,
math_ops.cast(beta1_power, grad.dtype.base_dtype),
math_ops.cast(beta2_power, grad.dtype.base_dtype),
math_ops.cast(self._lr_t, grad.dtype.base_dtype),
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
grad, use_locking=self._use_locking)
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
beta1_power, beta2_power = self._get_beta_accumulators()
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, "m")
m_scaled_g_values = grad * (1 - beta1_t)
m_t = state_ops.assign(m, m * beta1_t,
use_locking=self._use_locking)
with ops.control_dependencies([m_t]):
m_t = scatter_add(m, indices, m_scaled_g_values)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, "v")
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
with ops.control_dependencies([v_t]):
v_t = scatter_add(v, indices, v_scaled_g_values)
v_sqrt = math_ops.sqrt(v_t)
var_update = state_ops.assign_sub(var,
lr * m_t / (v_sqrt + epsilon_t),
use_locking=self._use_locking)
return control_flow_ops.group(*[var_update, m_t, v_t])
def _apply_sparse(self, grad, var):
return self._apply_sparse_shared(
grad.values, var, grad.indices,
lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
x, i, v, use_locking=self._use_locking))
def _resource_scatter_add(self, x, i, v):
with ops.control_dependencies(
[resource_variable_ops.resource_scatter_add(
x.handle, i, v)]):
return x.value()
def _resource_apply_sparse(self, grad, var, indices):
return self._apply_sparse_shared(
grad, var, indices, self._resource_scatter_add)
def _finish(self, update_ops, name_scope):
# Update the power accumulators.
with ops.control_dependencies(update_ops):
beta1_power, beta2_power = self._get_beta_accumulators()
with ops.colocate_with(beta1_power):
update_beta1 = beta1_power.assign(
beta1_power * self._beta1_t, use_locking=self._use_locking)
update_beta2 = beta2_power.assign(
beta2_power * self._beta2_t, use_locking=self._use_locking)
return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
name=name_scope)

546
data.py Normal file
View File

@ -0,0 +1,546 @@
from tensorflow.python.platform import flags
from tensorflow.contrib.data.python.ops import batching, threadpool
import tensorflow as tf
import json
from torch.utils.data import Dataset
import pickle
import os.path as osp
import os
import numpy as np
import time
from scipy.misc import imread, imresize
from skimage.color import rgb2grey
from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
from torchvision import transforms
from imagenet_preprocessing import ImagenetPreprocessor
import torch
import torchvision
FLAGS = flags.FLAGS
ROOT_DIR = "./results"
# Dataset Options
flags.DEFINE_string('dsprites_path',
'/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
'path to dsprites characters')
flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'whether cutoff should always in image')
flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
flags.DEFINE_bool('dsize_only', False,'fix all factors except for size of objects')
flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
flags.DEFINE_bool('dsprites_restrict', False, 'fix all factors except for rotation of objects')
flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')
# Data augmentation options
flags.DEFINE_bool('cutout_inside', False,'whether cutoff should always in image')
flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout')
flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout')
flags.DEFINE_bool('cutout', False,'whether to add cutout regularizer to data')
def cutout(mask_color=(0, 0, 0)):
mask_size_half = FLAGS.cutout_mask_size // 2
offset = 1 if FLAGS.cutout_mask_size % 2 == 0 else 0
def _cutout(image):
image = np.asarray(image).copy()
if np.random.random() > FLAGS.cutout_prob:
return image
h, w = image.shape[:2]
if FLAGS.cutout_inside:
cxmin, cxmax = mask_size_half, w + offset - mask_size_half
cymin, cymax = mask_size_half, h + offset - mask_size_half
else:
cxmin, cxmax = 0, w + offset
cymin, cymax = 0, h + offset
cx = np.random.randint(cxmin, cxmax)
cy = np.random.randint(cymin, cymax)
xmin = cx - mask_size_half
ymin = cy - mask_size_half
xmax = xmin + FLAGS.cutout_mask_size
ymax = ymin + FLAGS.cutout_mask_size
xmin = max(0, xmin)
ymin = max(0, ymin)
xmax = min(w, xmax)
ymax = min(h, ymax)
image[:, ymin:ymax, xmin:xmax] = np.array(mask_color)[:, None, None]
return image
return _cutout
class TFImagenetLoader(Dataset):
def __init__(self, split, batchsize, idx, num_workers, rescale=1):
IMAGENET_NUM_TRAIN_IMAGES = 1281167
IMAGENET_NUM_VAL_IMAGES = 50000
self.rescale = rescale
if split == "train":
im_length = IMAGENET_NUM_TRAIN_IMAGES
records_to_skip = im_length * idx // num_workers
records_to_read = im_length * (idx + 1) // num_workers - records_to_skip
else:
im_length = IMAGENET_NUM_VAL_IMAGES
self.curr_sample = 0
index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
with open(index_path) as f:
metadata = json.load(f)
counts = metadata['record_counts']
if split == 'train':
file_names = list(sorted([x for x in counts.keys() if x.startswith('train')]))
result_records_to_skip = None
files = []
for filename in file_names:
records_in_file = counts[filename]
if records_to_skip >= records_in_file:
records_to_skip -= records_in_file
continue
elif records_to_read > 0:
if result_records_to_skip is None:
# Record the number to skip in the first file
result_records_to_skip = records_to_skip
files.append(filename)
records_to_read -= (records_in_file - records_to_skip)
records_to_skip = 0
else:
break
else:
files = list(sorted([x for x in counts.keys() if x.startswith('validation')]))
files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess
ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
ds = ds.apply(tf.data.TFRecordDataset)
ds = ds.take(im_length)
ds = ds.prefetch(buffer_size=FLAGS.batch_size)
ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4))
ds = ds.prefetch(buffer_size=2)
ds_iterator = ds.make_initializable_iterator()
labels, images = ds_iterator.get_next()
self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
self.labels = labels
config = tf.ConfigProto(device_count = {'GPU': 0})
sess = tf.Session(config=config)
sess.run(ds_iterator.initializer)
self.im_length = im_length // batchsize
self.sess = sess
def __next__(self):
self.curr_sample += 1
sess = self.sess
im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
label, im = sess.run([self.labels, self.images])
im = im * self.rescale
label = np.eye(1000)[label.squeeze() - 1]
im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label)
return im_corrupt, im, label
def __iter__(self):
return self
def __len__(self):
return self.im_length
class CelebA(Dataset):
def __init__(self):
self.path = "/root/data/img_align_celeba"
self.ims = os.listdir(self.path)
self.ims = [osp.join(self.path, im) for im in self.ims]
def __len__(self):
return len(self.ims)
def __getitem__(self, index):
label = 1
if FLAGS.single:
index = 0
path = self.ims[index]
im = imread(path)
im = imresize(im, (32, 32))
image_size = 32
im = im / 255.
if FLAGS.datasource == 'default':
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random':
im_corrupt = np.random.uniform(
0, 1, size=(image_size, image_size, 3))
return im_corrupt, im, label
class Cifar10(Dataset):
def __init__(
self,
train=True,
full=False,
augment=False,
noise=True,
rescale=1.0):
if augment:
transform_list = [
torchvision.transforms.RandomCrop(32, padding=4),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
]
if FLAGS.cutout:
transform_list.append(cutout())
transform = transforms.Compose(transform_list)
else:
transform = transforms.ToTensor()
self.full = full
self.data = CIFAR10(
ROOT_DIR,
transform=transform,
train=train,
download=True)
self.test_data = CIFAR10(
ROOT_DIR,
transform=transform,
train=False,
download=True)
self.one_hot_map = np.eye(10)
self.noise = noise
self.rescale = rescale
def __len__(self):
if self.full:
return len(self.data) + len(self.test_data)
else:
return len(self.data)
def __getitem__(self, index):
if not FLAGS.single:
if self.full:
if index >= len(self.data):
im, label = self.test_data[index - len(self.data)]
else:
im, label = self.data[index]
else:
im, label = self.data[index]
else:
im, label = self.data[0]
im = np.transpose(im, (1, 2, 0)).numpy()
image_size = 32
label = self.one_hot_map[label]
im = im * 255 / 256
if self.noise:
im = im * self.rescale + \
np.random.uniform(0, self.rescale * 1 / 256., im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default':
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random':
im_corrupt = np.random.uniform(
0.0, self.rescale, (image_size, image_size, 3))
return im_corrupt, im, label
class Cifar100(Dataset):
def __init__(self, train=True, augment=False):
if augment:
transform_list = [
torchvision.transforms.RandomCrop(32, padding=4),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
]
if FLAGS.cutout:
transform_list.append(cutout())
transform = transforms.Compose(transform_list)
else:
transform = transforms.ToTensor()
self.data = CIFAR100(
"/root/cifar100",
transform=transform,
train=train,
download=True)
self.one_hot_map = np.eye(100)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
if not FLAGS.single:
im, label = self.data[index]
else:
im, label = self.data[0]
im = np.transpose(im, (1, 2, 0)).numpy()
image_size = 32
label = self.one_hot_map[label]
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default':
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random':
im_corrupt = np.random.uniform(
0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label
class Svhn(Dataset):
def __init__(self, train=True, augment=False):
transform = transforms.ToTensor()
self.data = SVHN("/root/svhn", transform=transform, download=True)
self.one_hot_map = np.eye(10)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
if not FLAGS.single:
im, label = self.data[index]
else:
em, label = self.data[0]
im = np.transpose(im, (1, 2, 0)).numpy()
image_size = 32
label = self.one_hot_map[label]
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default':
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random':
im_corrupt = np.random.uniform(
0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label
class Mnist(Dataset):
def __init__(self, train=True, rescale=1.0):
self.data = MNIST(
"/root/mnist",
transform=transforms.ToTensor(),
download=True, train=train)
self.labels = np.eye(10)
self.rescale = rescale
def __len__(self):
return len(self.data)
def __getitem__(self, index):
im, label = self.data[index]
label = self.labels[label]
im = im.squeeze()
# im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28))
# im = im.numpy() / 2 + 0.2
im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28))
im = im * self.rescale
image_size = 28
if FLAGS.datasource == 'default':
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
elif FLAGS.datasource == 'random':
im_corrupt = np.random.uniform(0, self.rescale, (28, 28))
return im_corrupt, im, label
class DSprites(Dataset):
def __init__(
self,
cond_size=False,
cond_shape=False,
cond_pos=False,
cond_rot=False):
dat = np.load(FLAGS.dsprites_path)
if FLAGS.dshape_only:
l = dat['latents_values']
mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 /
31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
self.label = np.tile(dat['latents_values'][mask], (10000, 1))
self.label = self.label[:, 1:2]
elif FLAGS.dpos_only:
l = dat['latents_values']
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
mask = (l[:, 1] == 1) & (
l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
self.label = np.tile(dat['latents_values'][mask], (100, 1))
self.label = self.label[:, 4:] + 0.5
elif FLAGS.dsize_only:
l = dat['latents_values']
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 /
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
self.label = np.tile(dat['latents_values'][mask], (10000, 1))
self.label = (self.label[:, 2:3])
elif FLAGS.drot_only:
l = dat['latents_values']
mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 /
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
self.label = np.tile(dat['latents_values'][mask], (100, 1))
self.label = (self.label[:, 3:4])
self.label = np.concatenate(
[np.cos(self.label), np.sin(self.label)], axis=1)
elif FLAGS.dsprites_restrict:
l = dat['latents_values']
mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39)
self.data = dat['imgs'][mask]
self.label = dat['latents_values'][mask]
else:
self.data = dat['imgs']
self.label = dat['latents_values']
if cond_size:
self.label = self.label[:, 2:3]
elif cond_shape:
self.label = self.label[:, 1:2]
elif cond_pos:
self.label = self.label[:, 4:]
elif cond_rot:
self.label = self.label[:, 3:4]
self.label = np.concatenate(
[np.cos(self.label), np.sin(self.label)], axis=1)
else:
self.label = self.label[:, 1:2]
self.identity = np.eye(3)
def __len__(self):
return self.data.shape[0]
def __getitem__(self, index):
im = self.data[index]
image_size = 64
if not (
FLAGS.dpos_only or FLAGS.dsize_only) and (
not FLAGS.cond_size) and (
not FLAGS.cond_pos) and (
not FLAGS.cond_rot) and (
not FLAGS.drot_only):
label = self.identity[self.label[index].astype(
np.int32) - 1].squeeze()
else:
label = self.label[index]
if FLAGS.datasource == 'default':
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
elif FLAGS.datasource == 'random':
im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)
return im_corrupt, im, label
class Imagenet(Dataset):
def __init__(self, train=True, augment=False):
if train:
for i in range(1, 11):
f = pickle.load(
open(
osp.join(
FLAGS.imagenet_path,
'train_data_batch_{}'.format(i)),
'rb'))
if i == 1:
labels = f['labels']
data = f['data']
else:
labels.extend(f['labels'])
data = np.vstack((data, f['data']))
else:
f = pickle.load(
open(
osp.join(
FLAGS.imagenet_path,
'val_data'),
'rb'))
labels = f['labels']
data = f['data']
self.labels = labels
self.data = data
self.one_hot_map = np.eye(1000)
def __len__(self):
return self.data.shape[0]
def __getitem__(self, index):
if not FLAGS.single:
im, label = self.data[index], self.labels[index]
else:
im, label = self.data[0], self.labels[0]
label -= 1
im = im.reshape((3, 32, 32)) / 255
im = im.transpose((1, 2, 0))
image_size = 32
label = self.one_hot_map[label]
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default':
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random':
im_corrupt = np.random.uniform(
0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label
class Textures(Dataset):
def __init__(self, train=True, augment=False):
self.dataset = ImageFolder("/mnt/nfs/yilundu/data/dtd/images")
def __len__(self):
return 2 * len(self.dataset)
def __getitem__(self, index):
idx = index % (len(self.dataset))
im, label = self.dataset[idx]
im = np.array(im)[:32, :32] / 255
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
return im, im, label

698
ebm_combine.py Normal file
View File

@ -0,0 +1,698 @@
import tensorflow as tf
import math
from tqdm import tqdm
from hmc import hmc
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader, Dataset
from models import DspritesNet
from utils import optimistic_restore, ReplayBuffer
import os.path as osp
import numpy as np
from rl_algs.logger import TensorBoardOutputFormat
from scipy.misc import imsave
import os
from custom_adam import AdamOptimizer
flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
flags.DEFINE_integer('data_workers', 4, 'Number of workers to do things')
flags.DEFINE_string('logdir', 'cachedir', 'directory for logging')
flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
flags.DEFINE_float('step_lr', 500, 'size of gradient descent size')
flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters')
flags.DEFINE_bool('cclass', True, 'not cclass')
flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results')
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.')
flags.DEFINE_bool('joint_shape', False, 'whether to use pos_size or pos_shape')
flags.DEFINE_bool('joint_rot', False, 'whether to use pos_size or pos_shape')
# Conditions on which models to use
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale')
flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of experiments')
flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of experiments')
flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of experiments')
flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of experiments')
flags.DEFINE_integer('resume_size', 169000, 'First iteration to resume')
flags.DEFINE_integer('resume_shape', 477000, 'Second iteration to resume')
flags.DEFINE_integer('resume_pos', 8000, 'Second iteration to resume')
flags.DEFINE_integer('resume_rot', 690000, 'Second iteration to resume')
flags.DEFINE_integer('break_steps', 300, 'steps to break')
# Whether to train for gentest
flags.DEFINE_bool('train', False, 'whether to train on generalization into multiple different predictions')
FLAGS = flags.FLAGS
class DSpritesGen(Dataset):
def __init__(self, data, latents, frac=0.0):
l = latents
if FLAGS.joint_shape:
mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
elif FLAGS.joint_rot:
mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
else:
mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1)
mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
data_pos = data[mask_pos]
l_pos = l[mask_pos]
data_size = data[mask_size]
l_size = l[mask_size]
n = data_pos.shape[0] // data_size.shape[0]
data_pos = np.tile(data_pos, (n, 1, 1))
l_pos = np.tile(l_pos, (n, 1))
self.data = np.concatenate((data_pos, data_size), axis=0)
self.label = np.concatenate((l_pos, l_size), axis=0)
mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39))
data_add = data[mask_neg]
l_add = l[mask_neg]
perm_idx = np.random.permutation(data_add.shape[0])
select_idx = perm_idx[:int(frac*perm_idx.shape[0])]
data_add = data_add[select_idx]
l_add = l_add[select_idx]
self.data = np.concatenate((self.data, data_add), axis=0)
self.label = np.concatenate((self.label, l_add), axis=0)
self.identity = np.eye(3)
def __len__(self):
return self.data.shape[0]
def __getitem__(self, index):
im = self.data[index]
im_corrupt = 0.5 + 0.5 * np.random.randn(64, 64)
if FLAGS.joint_shape:
label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
elif FLAGS.joint_rot:
label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])])
else:
label_size = self.label[index, 2:3]
label_pos = self.label[index, 4:]
return (im_corrupt, im, label_size, label_pos)
def labeldiscover(sess, kvs, data, latents, save_exp_dir):
LABEL_SIZE = kvs['LABEL_SIZE']
model_size = kvs['model_size']
weight_size = kvs['weight_size']
x_mod = kvs['X_NOISE']
label_output = LABEL_SIZE
for i in range(FLAGS.num_steps):
label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03)
e_noise = model_size.forward(x_mod, weight_size, label=label_output)
label_grad = tf.gradients(e_noise, [label_output])[0]
# label_grad = tf.Print(label_grad, [label_grad])
label_output = label_output - 1.0 * label_grad
label_output = tf.clip_by_value(label_output, 0.5, 1.0)
diffs = []
for i in range(30):
s = i*FLAGS.batch_size
d = (i+1)*FLAGS.batch_size
data_i = data[s:d]
latent_i = latents[s:d]
latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))
feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init}
size_pred = sess.run([label_output], feed_dict)[0]
size_gt = latent_i[:, 2:3]
diffs.append(np.abs(size_pred - size_gt).mean())
print(np.array(diffs).mean())
def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
# tf.reset_default_graph()
if FLAGS.joint_shape:
model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=5)
LABEL = tf.placeholder(shape=(None, 5), dtype=tf.float32)
else:
model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))
X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32)
X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
loss_sq = tf.reduce_mean(tf.square(X_out - X_label))
optimizer = AdamOptimizer(1e-3)
gvs = optimizer.compute_gradients(loss_sq)
gvs = [(k, v) for (k, v) in gvs if k is not None]
train_op = optimizer.apply_gradients(gvs)
dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
datafull = data
itr = 0
saver = tf.train.Saver()
vs = optimizer.variables()
sess.run(tf.global_variables_initializer())
if FLAGS.train:
for _ in range(5):
for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
data_corrupt = data_corrupt.numpy()
label_size, label_pos = label_size.numpy(), label_pos.numpy()
data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters)
label_comb = np.concatenate([label_size, label_pos], axis=1)
feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}
output = [loss_sq, train_op]
loss, _ = sess.run(output, feed_dict=feed_dict)
itr += 1
saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))
saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))
l = latents
if FLAGS.joint_shape:
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
else:
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
data_gen = datafull[mask_gen]
latents_gen = latents[mask_gen]
losses = []
for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters)
if FLAGS.joint_shape:
latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
latent_pos = latent[:, 4:6]
latent = np.concatenate([latent_size, latent_pos], axis=1)
feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
else:
feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat}
loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
# print(loss)
losses.append(loss)
print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))
data_try = data_gen[:10]
data_init = np.random.randn(10, 2*FLAGS.num_filters)
if FLAGS.joint_shape:
latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1]
latent_pos = latents_gen[:10, 4:]
else:
latent_scale = latents_gen[:10, 2:3]
latent_pos = latents_gen[:10, 4:]
latent_tot = np.concatenate([latent_scale, latent_pos], axis=1)
feed_dict = {X_feed: data_init, LABEL: latent_tot}
x_output = sess.run([X_out], feed_dict=feed_dict)[0]
x_output = np.clip(x_output, 0, 1)
im_name = "size_scale_combine_genbaseline.png"
x_output_wrap = np.ones((10, 66, 66))
data_try_wrap = np.ones((10, 66, 66))
x_output_wrap[:, 1:-1, 1:-1] = x_output
data_try_wrap[:, 1:-1, 1:-1] = data_try
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output)
print("Successfully saved images at {}".format(impath))
return np.mean(losses)
def gentest(sess, kvs, data, latents, save_exp_dir):
X_NOISE = kvs['X_NOISE']
LABEL_SIZE = kvs['LABEL_SIZE']
LABEL_SHAPE = kvs['LABEL_SHAPE']
LABEL_POS = kvs['LABEL_POS']
LABEL_ROT = kvs['LABEL_ROT']
model_size = kvs['model_size']
model_shape = kvs['model_shape']
model_pos = kvs['model_pos']
model_rot = kvs['model_rot']
weight_size = kvs['weight_size']
weight_shape = kvs['weight_shape']
weight_pos = kvs['weight_pos']
weight_rot = kvs['weight_rot']
X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
datafull = data
# Test combination of generalization where we use slices of both training
x_final = X_NOISE
x_mod_size = X_NOISE
x_mod_pos = X_NOISE
for i in range(FLAGS.num_steps):
# use cond_pos
energies = []
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)
# energies.append(e_noise)
x_grad = tf.gradients(e_noise, [x_final])[0]
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)
if FLAGS.joint_shape:
# use cond_shape
e_noise = model_shape.forward(x_mod_pos, weight_shape, label=LABEL_SHAPE)
elif FLAGS.joint_rot:
e_noise = model_rot.forward(x_mod_pos, weight_rot, label=LABEL_ROT)
else:
# use cond_size
e_noise = model_size.forward(x_mod_pos, weight_size, label=LABEL_SIZE)
# energies.append(e_noise)
# energy_stack = tf.concat(energies, axis=1)
# energy_stack = tf.reduce_logsumexp(-1*energy_stack, axis=1)
# energy_stack = tf.reduce_sum(energy_stack, axis=1)
x_grad = tf.gradients(e_noise, [x_mod_pos])[0]
x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)
# for x_mod_size
# use cond_size
# e_noise = model_size.forward(x_mod_size, weight_size, label=LABEL_SIZE)
# x_grad = tf.gradients(e_noise, [x_mod_size])[0]
# x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
# x_mod_size = x_mod_size - FLAGS.step_lr * x_grad
# x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)
# # use cond_pos
# e_noise = model_pos.forward(x_mod_size, weight_pos, label=LABEL_POS)
# x_grad = tf.gradients(e_noise, [x_mod_size])[0]
# x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
# x_mod_size = x_mod_size - FLAGS.step_lr * tf.stop_gradient(x_grad)
# x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)
x_mod = x_mod_pos
x_final = x_mod
if FLAGS.joint_shape:
loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
elif FLAGS.joint_rot:
loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
else:
loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
neg_loss = coeff * (-1*energy_neg) / norm_constant
loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))
optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
gvs = optimizer.compute_gradients(loss_total)
gvs = [(k, v) for (k, v) in gvs if k is not None]
train_op = optimizer.apply_gradients(gvs)
vs = optimizer.variables()
sess.run(tf.variables_initializer(vs))
dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
x_off = tf.reduce_mean(tf.square(x_mod - X))
itr = 0
saver = tf.train.Saver()
x_mod = None
if FLAGS.train:
replay_buffer = ReplayBuffer(10000)
for _ in range(1):
for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
data_corrupt = data_corrupt.numpy()[:, :, :]
data = data.numpy()[:, :, :]
if x_mod is not None:
replay_buffer.add(x_mod)
replay_batch = replay_buffer.sample(FLAGS.batch_size)
replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
data_corrupt[replay_mask] = replay_batch[replay_mask]
if FLAGS.joint_shape:
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
elif FLAGS.joint_rot:
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
else:
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}
_, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
itr += 1
if itr % 10 == 0:
print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr))
if itr == FLAGS.break_steps:
break
saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))
saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))
l = latents
if FLAGS.joint_shape:
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
elif FLAGS.joint_rot:
mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
else:
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
data_gen = datafull[mask_gen]
latents_gen = latents[mask_gen]
losses = []
for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
x = 0.5 + np.random.randn(*dat.shape)
if FLAGS.joint_shape:
feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
elif FLAGS.joint_rot:
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
else:
feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
for i in range(2):
x = sess.run([x_final], feed_dict=feed_dict)[0]
feed_dict[X_NOISE] = x
loss = sess.run([x_off], feed_dict=feed_dict)[0]
losses.append(loss)
print("Mean MSE loss of {} ".format(np.mean(losses)))
data_try = data_gen[:10]
data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
latent_scale = latents_gen[:10, 2:3]
latent_pos = latents_gen[:10, 4:]
if FLAGS.joint_shape:
feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos}
elif FLAGS.joint_rot:
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1), LABEL_POS: latent[:10, 4:], X_NOISE: data_init}
else:
feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}
x_output = sess.run([x_final], feed_dict=feed_dict)[0]
if FLAGS.joint_shape:
im_name = "size_shape_combine_gentest.png"
else:
im_name = "size_scale_combine_gentest.png"
x_output_wrap = np.ones((10, 66, 66))
data_try_wrap = np.ones((10, 66, 66))
x_output_wrap[:, 1:-1, 1:-1] = x_output
data_try_wrap[:, 1:-1, 1:-1] = data_try
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output)
print("Successfully saved images at {}".format(impath))
def conceptcombine(sess, kvs, data, latents, save_exp_dir):
X_NOISE = kvs['X_NOISE']
LABEL_SIZE = kvs['LABEL_SIZE']
LABEL_SHAPE = kvs['LABEL_SHAPE']
LABEL_POS = kvs['LABEL_POS']
LABEL_ROT = kvs['LABEL_ROT']
model_size = kvs['model_size']
model_shape = kvs['model_shape']
model_pos = kvs['model_pos']
model_rot = kvs['model_rot']
weight_size = kvs['weight_size']
weight_shape = kvs['weight_shape']
weight_pos = kvs['weight_pos']
weight_rot = kvs['weight_rot']
x_mod = X_NOISE
for i in range(FLAGS.num_steps):
if FLAGS.cond_scale:
e_noise = model_size.forward(x_mod, weight_size, label=LABEL_SIZE)
x_grad = tf.gradients(e_noise, [x_mod])[0]
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
x_mod = x_mod - FLAGS.step_lr * x_grad
x_mod = tf.clip_by_value(x_mod, 0, 1)
if FLAGS.cond_shape:
e_noise = model_shape.forward(x_mod, weight_shape, label=LABEL_SHAPE)
x_grad = tf.gradients(e_noise, [x_mod])[0]
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
x_mod = x_mod - FLAGS.step_lr * x_grad
x_mod = tf.clip_by_value(x_mod, 0, 1)
if FLAGS.cond_pos:
e_noise = model_pos.forward(x_mod, weight_pos, label=LABEL_POS)
x_grad = tf.gradients(e_noise, [x_mod])[0]
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
x_mod = x_mod - FLAGS.step_lr * x_grad
x_mod = tf.clip_by_value(x_mod, 0, 1)
if FLAGS.cond_rot:
e_noise = model_rot.forward(x_mod, weight_rot, label=LABEL_ROT)
x_grad = tf.gradients(e_noise, [x_mod])[0]
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
x_mod = x_mod - FLAGS.step_lr * x_grad
x_mod = tf.clip_by_value(x_mod, 0, 1)
print("Finished constructing loop {}".format(i))
x_final = x_mod
data_try = data[:10]
data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
label_scale = latents[:10, 2:3]
label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)]
label_rot = latents[:10, 3:4]
label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
label_pos = latents[:10, 4:]
feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos,
LABEL_ROT: label_rot}
x_out = sess.run([x_final], feed_dict)[0]
im_name = "im"
if FLAGS.cond_scale:
im_name += "_condscale"
if FLAGS.cond_shape:
im_name += "_condshape"
if FLAGS.cond_pos:
im_name += "_condpos"
if FLAGS.cond_rot:
im_name += "_condrot"
im_name += ".png"
x_out_pad, data_try_pad = np.ones((10, 66, 66)), np.ones((10, 66, 66))
x_out_pad[:, 1:-1, 1:-1] = x_out
data_try_pad[:, 1:-1, 1:-1] = data_try
im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2)
impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output)
print("Successfully saved images at {}".format(impath))
def main():
data = np.load(FLAGS.dsprites_path)['imgs']
l = latents = np.load(FLAGS.dsprites_path)['latents_values']
np.random.seed(1)
idx = np.random.permutation(data.shape[0])
data = data[idx]
latents = latents[idx]
config = tf.ConfigProto()
sess = tf.Session(config=config)
# Model 1 will be conditioned on size
model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
weight_size = model_size.construct_weights('context_0')
# Model 2 will be conditioned on shape
model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
weight_shape = model_shape.construct_weights('context_1')
# Model 3 will be conditioned on position
model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
weight_pos = model_pos.construct_weights('context_2')
# Model 4 will be conditioned on rotation
model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
weight_rot = model_rot.construct_weights('context_3')
sess.run(tf.global_variables_initializer())
save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size))
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0))
v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list}
if FLAGS.cond_scale:
saver = tf.train.Saver(v_map)
saver.restore(sess, save_path_size)
save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape))
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
if FLAGS.cond_shape:
saver = tf.train.Saver(v_map)
saver.restore(sess, save_path_shape)
save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos))
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
saver = tf.train.Saver(v_map)
if FLAGS.cond_pos:
saver.restore(sess, save_path_pos)
save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot))
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3))
v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list}
saver = tf.train.Saver(v_map)
if FLAGS.cond_rot:
saver.restore(sess, save_path_rot)
X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
LABEL_SIZE = tf.placeholder(shape=(None, 1), dtype=tf.float32)
LABEL_SHAPE = tf.placeholder(shape=(None, 3), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)
x_mod = X_NOISE
kvs = {}
kvs['X_NOISE'] = X_NOISE
kvs['LABEL_SIZE'] = LABEL_SIZE
kvs['LABEL_SHAPE'] = LABEL_SHAPE
kvs['LABEL_POS'] = LABEL_POS
kvs['LABEL_ROT'] = LABEL_ROT
kvs['model_size'] = model_size
kvs['model_shape'] = model_shape
kvs['model_pos'] = model_pos
kvs['model_rot'] = model_rot
kvs['weight_size'] = weight_size
kvs['weight_shape'] = weight_shape
kvs['weight_pos'] = weight_pos
kvs['weight_rot'] = weight_rot
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape))
if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir)
if FLAGS.task == 'conceptcombine':
conceptcombine(sess, kvs, data, latents, save_exp_dir)
elif FLAGS.task == 'labeldiscover':
labeldiscover(sess, kvs, data, latents, save_exp_dir)
elif FLAGS.task == 'gentest':
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos))
if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir)
gentest(sess, kvs, data, latents, save_exp_dir)
elif FLAGS.task == 'genbaseline':
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos))
if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir)
if FLAGS.plot_curve:
mse_losses = []
for frac in [i/10 for i in range(11)]:
mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac)
mse_losses.append(mse_loss)
np.save("mse_baseline_comb.npy", mse_losses)
else:
genbaseline(sess, kvs, data, latents, save_exp_dir)
if __name__ == "__main__":
main()

981
ebm_sandbox.py Normal file
View File

@ -0,0 +1,981 @@
import tensorflow as tf
import math
from tqdm import tqdm
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
import torch
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, DspritesNet
from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, DSprites
from utils import optimistic_restore, set_seed
import os.path as osp
import numpy as np
from baselines.logger import TensorBoardOutputFormat
from scipy.misc import imsave
import os
import sklearn.metrics as sk
from baselines.common.tf_util import initialize
from scipy.linalg import eig
import matplotlib.pyplot as plt
# set_seed(1)
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
flags.DEFINE_string('dataset', 'cifar10', 'omniglot or imagenet or omniglotfull or cifar10 or mnist or dsprites')
flags.DEFINE_string('logdir', 'sandbox_cachedir', 'location where log of experiments will be stored')
flags.DEFINE_string('task', 'label', 'using conditional energy based models for classification'
'anticorrupt: restore salt and pepper noise),'
' boxcorrupt: restore empty portion of image'
'or crossclass: change images from one class to another'
'or cycleclass: view image change across a label'
'or nearestneighbor which returns the nearest images in the test set'
'or latent to traverse the latent space of an EBM through eigenvectors of the hessian (dsprites only)'
'or mixenergy to evaluate out of distribution generalization compared to other datasets')
flags.DEFINE_bool('hessian', True, 'Whether to use the hessian or the Jacobian for latent traversals')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
flags.DEFINE_integer('batch_size', 32, 'Size of inputs')
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('train', True, 'Whether to train or test network')
flags.DEFINE_bool('single', False, 'whether to use one sample to debug')
flags.DEFINE_bool('cclass', True, 'whether to use a conditional model (required for task label)')
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
flags.DEFINE_float('step_lr', 10.0, 'step size for updates on label')
flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images')
flags.DEFINE_bool('large_model', False, 'Whether to use a large model')
flags.DEFINE_bool('larger_model', False, 'Whether to use a larger model')
flags.DEFINE_bool('wider_model', False, 'Whether to use a widermodel model')
flags.DEFINE_bool('svhn', False, 'Whether to test on SVHN')
# Conditions for mixenergy (outlier detection)
flags.DEFINE_bool('svhnmix', False, 'Whether to test mix on SVHN')
flags.DEFINE_bool('cifar100mix', False, 'Whether to test mix on CIFAR100')
flags.DEFINE_bool('texturemix', False, 'Whether to test mix on Textures dataset')
flags.DEFINE_bool('randommix', False, 'Whether to test mix on random dataset')
# Conditions for label task (adversarial classification)
flags.DEFINE_integer('lival', 8, 'Value of constraint for li')
flags.DEFINE_integer('l2val', 40, 'Value of constraint for l2')
flags.DEFINE_integer('pgd', 0, 'number of steps project gradient descent to run')
flags.DEFINE_integer('lnorm', -1, 'linfinity is -1, l2 norm is 2')
flags.DEFINE_bool('labelgrid', False, 'Make a grid of labels')
# Conditions on which models to use
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
flags.DEFINE_bool('cond_size', True, 'whether to condition on scale')
FLAGS = flags.FLAGS
def rescale_im(im):
im = np.clip(im, 0, 1)
return np.round(im * 255).astype(np.uint8)
def label(dataloader, test_dataloader, target_vars, sess, l1val=8, l2val=40):
X = target_vars['X']
Y = target_vars['Y']
Y_GT = target_vars['Y_GT']
accuracy = target_vars['accuracy']
train_op = target_vars['train_op']
l1_norm = target_vars['l1_norm']
l2_norm = target_vars['l2_norm']
label_init = np.random.uniform(0, 1, (FLAGS.batch_size, 10))
label_init = label_init / label_init.sum(axis=1, keepdims=True)
label_init = np.tile(np.eye(10)[None :, :], (FLAGS.batch_size, 1, 1))
label_init = np.reshape(label_init, (-1, 10))
for i in range(1):
emp_accuracies = []
for data_corrupt, data, label_gt in tqdm(test_dataloader):
feed_dict = {X: data, Y_GT: label_gt, Y: label_init, l1_norm: l1val, l2_norm: l2val}
emp_accuracy = sess.run([accuracy], feed_dict)
emp_accuracies.append(emp_accuracy)
print(np.array(emp_accuracies).mean())
print("Received total accuracy of {} for li of {} and l2 of {}".format(np.array(emp_accuracies).mean(), l1val, l2val))
return np.array(emp_accuracies).mean()
def labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=8, l2val=40):
X = target_vars['X']
Y = target_vars['Y']
Y_GT = target_vars['Y_GT']
accuracy = target_vars['accuracy']
train_op = target_vars['train_op']
l1_norm = target_vars['l1_norm']
l2_norm = target_vars['l2_norm']
label_init = np.random.uniform(0, 1, (FLAGS.batch_size, 10))
label_init = label_init / label_init.sum(axis=1, keepdims=True)
label_init = np.tile(np.eye(10)[None :, :], (FLAGS.batch_size, 1, 1))
label_init = np.reshape(label_init, (-1, 10))
itr = 0
if FLAGS.train:
for i in range(1):
for data_corrupt, data, label_gt in tqdm(dataloader):
feed_dict = {X: data, Y_GT: label_gt, Y: label_init}
acc, _ = sess.run([accuracy, train_op], feed_dict)
itr += 1
if itr % 10 == 0:
print(acc)
saver.save(sess, osp.join(savedir, "model_supervised"))
saver.restore(sess, osp.join(savedir, "model_supervised"))
for i in range(1):
emp_accuracies = []
for data_corrupt, data, label_gt in tqdm(test_dataloader):
feed_dict = {X: data, Y_GT: label_gt, Y: label_init, l1_norm: l1val, l2_norm: l2val}
emp_accuracy = sess.run([accuracy], feed_dict)
emp_accuracies.append(emp_accuracy)
print(np.array(emp_accuracies).mean())
print("Received total accuracy of {} for li of {} and l2 of {}".format(np.array(emp_accuracies).mean(), l1val, l2val))
return np.array(emp_accuracies).mean()
def energyeval(dataloader, test_dataloader, target_vars, sess):
X = target_vars['X']
Y_GT = target_vars['Y_GT']
energy = target_vars['energy']
energy_end = target_vars['energy_end']
test_energies = []
train_energies = []
for data_corrupt, data, label_gt in tqdm(test_dataloader):
feed_dict = {X: data, Y_GT: label_gt}
test_energy = sess.run([energy], feed_dict)[0]
test_energies.extend(list(test_energy))
for data_corrupt, data, label_gt in tqdm(dataloader):
feed_dict = {X: data, Y_GT: label_gt}
train_energy = sess.run([energy], feed_dict)[0]
train_energies.extend(list(train_energy))
print(len(train_energies))
print(len(test_energies))
print("Train energies of {} with std {}".format(np.mean(train_energies), np.std(train_energies)))
print("Test energies of {} with std {}".format(np.mean(test_energies), np.std(test_energies)))
np.save("train_ebm.npy", train_energies)
np.save("test_ebm.npy", test_energies)
def energyevalmix(dataloader, test_dataloader, target_vars, sess):
X = target_vars['X']
Y_GT = target_vars['Y_GT']
energy = target_vars['energy']
if FLAGS.svhnmix:
dataset = Svhn(train=False)
test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)
test_iter = iter(test_dataloader_val)
elif FLAGS.cifar100mix:
dataset = Cifar100(train=False)
test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)
test_iter = iter(test_dataloader_val)
elif FLAGS.texturemix:
dataset = Textures()
test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)
test_iter = iter(test_dataloader_val)
probs = []
labels = []
negs = []
pos = []
for data_corrupt, data, label_gt in tqdm(test_dataloader):
data = data.numpy()
data_corrupt = data_corrupt.numpy()
if FLAGS.svhnmix:
_, data_mix, _ = test_iter.next()
elif FLAGS.cifar100mix:
_, data_mix, _ = test_iter.next()
elif FLAGS.texturemix:
_, data_mix, _ = test_iter.next()
elif FLAGS.randommix:
data_mix = np.random.randn(FLAGS.batch_size, 32, 32, 3) * 0.5 + 0.5
else:
data_idx = np.concatenate([np.arange(1, data.shape[0]), [0]])
data_other = data[data_idx]
data_mix = (data + data_other) / 2
data_mix = data_mix[:data.shape[0]]
if FLAGS.cclass:
# It's unfair to take a random class
label_gt= np.tile(np.eye(10), (data.shape[0], 1, 1))
label_gt = label_gt.reshape(data.shape[0] * 10, 10)
data_mix = np.tile(data_mix[:, None, :, :, :], (1, 10, 1, 1, 1))
data = np.tile(data[:, None, :, :, :], (1, 10, 1, 1, 1))
data_mix = data_mix.reshape(-1, 32, 32, 3)
data = data.reshape(-1, 32, 32, 3)
feed_dict = {X: data, Y_GT: label_gt}
feed_dict_neg = {X: data_mix, Y_GT: label_gt}
pos_energy = sess.run([energy], feed_dict)[0]
neg_energy = sess.run([energy], feed_dict_neg)[0]
if FLAGS.cclass:
pos_energy = pos_energy.reshape(-1, 10).min(axis=1)
neg_energy = neg_energy.reshape(-1, 10).min(axis=1)
probs.extend(list(-1*pos_energy))
probs.extend(list(-1*neg_energy))
pos.extend(list(-1*pos_energy))
negs.extend(list(-1*neg_energy))
labels.extend([1]*pos_energy.shape[0])
labels.extend([0]*neg_energy.shape[0])
pos, negs = np.array(pos), np.array(negs)
np.save("pos.npy", pos)
np.save("neg.npy", negs)
auroc = sk.roc_auc_score(labels, probs)
print("Roc score of {}".format(auroc))
def anticorrupt(dataloader, weights, model, target_vars, logdir, sess):
X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
for data_corrupt, data, label_gt in tqdm(dataloader):
data, label_gt = data.numpy(), label_gt.numpy()
noise = np.random.uniform(0, 1, size=[data.shape[0], data.shape[1], data.shape[2]])
low_mask = noise < 0.05
high_mask = (noise > 0.05) & (noise < 0.1)
print(high_mask.shape)
data_corrupt = data.copy()
data_corrupt[low_mask] = 0.1
data_corrupt[high_mask] = 0.9
data_corrupt_init = data_corrupt
for i in range(5):
feed_dict = {X: data_corrupt, Y_GT: label_gt}
data_corrupt = sess.run([X_final], feed_dict)[0]
data_uncorrupt = data_corrupt
data_corrupt, data_uncorrupt, data = rescale_im(data_corrupt_init), rescale_im(data_uncorrupt), rescale_im(data)
panel_im = np.zeros((32*20, 32*3, 3)).astype(np.uint8)
for i in range(20):
panel_im[32*i:32*i+32, :32] = data_corrupt[i]
panel_im[32*i:32*i+32, 32:64] = data_uncorrupt[i]
panel_im[32*i:32*i+32, 64:] = data[i]
imsave(osp.join(logdir, "anticorrupt.png"), panel_im)
assert False
def boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess):
X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
eval_im = 10000
data_diff = []
for data_corrupt, data, label_gt in tqdm(dataloader):
data, label_gt = data.numpy(), label_gt.numpy()
data_uncorrupts = []
data_corrupt = data.copy()
data_corrupt[:, 16:, :] = np.random.uniform(0, 1, (FLAGS.batch_size, 16, 32, 3))
data_corrupt_init = data_corrupt
for j in range(10):
feed_dict = {X: data_corrupt, Y_GT: label_gt}
data_corrupt = sess.run([X_final], feed_dict)[0]
val = np.mean(np.square(data_corrupt - data), axis=(1, 2, 3))
data_diff.extend(list(val))
if len(data_diff) > eval_im:
break
print("Mean {} and std {} for train dataloader".format(np.mean(data_diff), np.std(data_diff)))
np.save("data_diff_train_image.npy", data_diff)
data_diff = []
for data_corrupt, data, label_gt in tqdm(test_dataloader):
data, label_gt = data.numpy(), label_gt.numpy()
data_uncorrupts = []
data_corrupt = data.copy()
data_corrupt[:, 16:, :] = np.random.uniform(0, 1, (FLAGS.batch_size, 16, 32, 3))
data_corrupt_init = data_corrupt
for j in range(10):
feed_dict = {X: data_corrupt, Y_GT: label_gt}
data_corrupt = sess.run([X_final], feed_dict)[0]
data_diff.extend(list(np.mean(np.square(data_corrupt - data), axis=(1, 2, 3))))
if len(data_diff) > eval_im:
break
print("Mean {} and std {} for test dataloader".format(np.mean(data_diff), np.std(data_diff)))
np.save("data_diff_test_image.npy", data_diff)
def crossclass(dataloader, weights, model, target_vars, logdir, sess):
X, Y_GT, X_mods, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_mods'], target_vars['X_final']
for data_corrupt, data, label_gt in tqdm(dataloader):
data, label_gt = data.numpy(), label_gt.numpy()
data_corrupt = data.copy()
data_corrupt[1:] = data_corrupt[0:-1]
data_corrupt[0] = data[-1]
data_mods = []
data_mod = data_corrupt
for i in range(10):
data_mods.append(data_mod)
feed_dict = {X: data_mod, Y_GT: label_gt}
data_mod = sess.run(X_final, feed_dict)
data_corrupt, data = rescale_im(data_corrupt), rescale_im(data)
data_mods = [rescale_im(data_mod) for data_mod in data_mods]
panel_im = np.zeros((32*20, 32*(len(data_mods) + 2), 3)).astype(np.uint8)
for i in range(20):
panel_im[32*i:32*i+32, :32] = data_corrupt[i]
for j in range(len(data_mods)):
panel_im[32*i:32*i+32, 32*(j+1):32*(j+2)] = data_mods[j][i]
panel_im[32*i:32*i+32, -32:] = data[i]
imsave(osp.join(logdir, "crossclass.png"), panel_im)
assert False
def cycleclass(dataloader, weights, model, target_vars, logdir, sess):
# X, Y_GT, X_final, X_targ = target_vars['X'], target_vars['Y_GT'], target_vars['X_final'], target_vars['X_targ']
X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
for data_corrupt, data, label_gt in tqdm(dataloader):
data, label_gt = data.numpy(), label_gt.numpy()
data_corrupt = data_corrupt.numpy()
data_mods = []
x_curr = data_corrupt
x_target = np.random.uniform(0, 1, data_corrupt.shape)
# x_target = np.tile(x_target, (1, 32, 32, 1))
for i in range(20):
feed_dict = {X: x_curr, Y_GT: label_gt}
x_curr_new = sess.run(X_final, feed_dict)
x_curr = x_curr_new
data_mods.append(x_curr_new)
if i > 30:
x_target = np.random.uniform(0, 1, data_corrupt.shape)
data_corrupt, data = rescale_im(data_corrupt), rescale_im(data)
data_mods = [rescale_im(data_mod) for data_mod in data_mods]
panel_im = np.zeros((32*100, 32*(len(data_mods) + 2), 3)).astype(np.uint8)
for i in range(100):
panel_im[32*i:32*i+32, :32] = data_corrupt[i]
for j in range(len(data_mods)):
panel_im[32*i:32*i+32, 32*(j+1):32*(j+2)] = data_mods[j][i]
panel_im[32*i:32*i+32, -32:] = data[i]
imsave(osp.join(logdir, "cycleclass.png"), panel_im)
assert False
def democlass(dataloader, weights, model, target_vars, logdir, sess):
X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
panel_im = np.zeros((5*32, 10*32, 3)).astype(np.uint8)
for i in range(10):
data_corrupt = np.random.uniform(0, 1, (5, 32, 32, 3))
label_gt = np.tile(np.eye(10)[i:i+1], (5, 1))
feed_dict = {X: data_corrupt, Y_GT: label_gt}
x_final = sess.run([X_final], feed_dict)[0]
x_final = rescale_im(x_final)
row = i // 2
col = i % 2
start_idx = col * 32 * 5
row_idx = row * 32
for j in range(5):
panel_im[row_idx:row_idx+32, start_idx+j*32:start_idx+(j+1) * 32] = x_final[j]
imsave(osp.join(logdir, "democlass.png"), panel_im)
def construct_finetune_label(weight, X, Y, Y_GT, model, target_vars):
l1_norm = tf.placeholder(shape=(), dtype=tf.float32)
l2_norm = tf.placeholder(shape=(), dtype=tf.float32)
def compute_logit(X, stop_grad=False, num_steps=0):
batch_size = tf.shape(X)[0]
X = tf.reshape(X, (batch_size, 1, 32, 32, 3))
X = tf.reshape(tf.tile(X, (1, 10, 1, 1, 1)), (batch_size * 10, 32, 32, 3))
Y_new = tf.reshape(Y, (batch_size*10, 10))
X_min = X - 8 / 255.
X_max = X + 8 / 255.
for i in range(num_steps):
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
energy_noise = model.forward(X, weights, label=Y, reuse=True)
x_grad = tf.gradients(energy_noise, [X])[0]
if FLAGS.proj_norm != 0.0:
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
X = X - FLAGS.step_lr * x_grad
X = tf.maximum(tf.minimum(X, X_max), X_min)
energy = model.forward(X, weight, label=Y_new)
energy = -tf.reshape(energy, (batch_size, 10))
if stop_grad:
energy = tf.stop_gradient(energy)
return energy
for i in range(FLAGS.pgd):
if FLAGS.train:
break
print("Constructed loop {} of pgd attack".format(i))
X_init = X
if i == 0:
X = X + tf.to_float(tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)) / 255.
logit = compute_logit(X)
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logit)
x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255.
X = X + 2 * x_grad
if FLAGS.lnorm == -1:
X = tf.maximum(tf.minimum(X, X_max), X_min)
elif FLAGS.lnorm == 2:
X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255., axes=[1, 2, 3])
energy = compute_logit(X, num_steps=0)
logits = energy
labels = tf.argmax(Y_GT, axis=1)
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logits)
optimizer = tf.train.AdamOptimizer(1e-3)
train_op = optimizer.minimize(loss)
accuracy = tf.contrib.metrics.accuracy(tf.argmax(logits, axis=1), labels)
target_vars['accuracy'] = accuracy
target_vars['train_op'] = train_op
target_vars['l1_norm'] = l1_norm
target_vars['l2_norm'] = l2_norm
def construct_latent(weights, X, Y_GT, model, target_vars):
eps = 0.001
X_init = X[0:1]
def traversals(model, X, weights, Y_GT):
if FLAGS.hessian:
e_pos = model.forward(X, weights, label=Y_GT)
hessian = tf.hessians(e_pos, X)
hessian = tf.reshape(hessian, (1, 64*64, 64*64))[0]
e, v = tf.linalg.eigh(hessian)
else:
latent = model.forward(X, weights, label=Y_GT, return_logit=True)
latents = tf.split(latent, 128, axis=1)
jacobian = [tf.gradients(latent, X)[0] for latent in latents]
jacobian = tf.stack(jacobian, axis=1)
jacobian = tf.reshape(jacobian, (tf.shape(jacobian)[1], tf.shape(jacobian)[1], 64*64))
s, _, v = tf.linalg.svd(jacobian)
return v
var_scale = 1.0
n = 3
xs = []
v = traversals(model, X_init, weights, Y_GT)
for i in range(n):
var = tf.reshape(v[:, i], (1, 64, 64))
X_plus = X_init - var_scale * var
X_min = X_init + var_scale * var
xs.extend([X_plus, X_min])
x_stack = tf.stack(xs, axis=0)
e_pos_hess_modify = model.forward(x_stack, weights, label=Y_GT)
for i in range(20):
x_stack = x_stack + tf.random_normal(tf.shape(x_stack), mean=0.0, stddev=0.005)
e_pos = model.forward(x_stack, weights, label=Y_GT)
x_grad = tf.gradients(e_pos, [x_stack])[0]
x_stack = x_stack - 4*FLAGS.step_lr * x_grad
x_stack = tf.clip_by_value(x_stack, 0, 1)
x_mods = tf.split(X, 6)
eigs = []
for j in range(6):
x_mod = x_mods[j]
v = traversals(model, x_mod, weights, Y_GT)
idx = j // 2
var = tf.reshape(v[:, idx], (1, 64, 64))
if j % 2 == 1:
x_mod = x_mod + var_scale * var
eigs.append(var)
else:
x_mod = x_mod - var_scale * var
eigs.append(-var)
x_mod = tf.clip_by_value(x_mod, 0, 1)
x_mods[j] = x_mod
x_mods_stack = tf.stack(x_mods, axis=0)
eigs_stack = tf.stack(eigs, axis=0)
energys = []
for i in range(20):
x_mods_stack = x_mods_stack + tf.random_normal(tf.shape(x_mods_stack), mean=0.0, stddev=0.005)
e_pos = model.forward(x_mods_stack, weights, label=Y_GT)
x_grad = tf.gradients(e_pos, [x_mods_stack])[0]
x_mods_stack = x_mods_stack - 4*FLAGS.step_lr * x_grad
# x_mods_stack = x_mods_stack + 0.1 * eigs_stack
x_mods_stack = tf.clip_by_value(x_mods_stack, 0, 1)
energys.append(e_pos)
x_refine = x_mods_stack
es = tf.stack(energys, axis=0)
# target_vars['hessian'] = hessian
# target_vars['e'] = e
target_vars['v'] = v
target_vars['x_stack'] = x_stack
target_vars['x_refine'] = x_refine
target_vars['es'] = es
# target_vars['e_base'] = e_pos_base
def latent(test_dataloader, weights, model, target_vars, sess):
X = target_vars['X']
Y_GT = target_vars['Y_GT']
# hessian = target_vars['hessian']
# e = target_vars['e']
v = target_vars['v']
x_stack = target_vars['x_stack']
x_refine = target_vars['x_refine']
es = target_vars['es']
# e_pos_base = target_vars['e_base']
# e_pos_hess_modify = target_vars['e_pos_hessian']
data_corrupt, data, label_gt = iter(test_dataloader).next()
data = data.numpy()
x_init = np.tile(data[0:1], (6, 1, 1))
x_mod, = sess.run([x_stack], {X: data})
# print("Value of original starting image: ", e_pos)
# print("Value of energy of hessian: ", e_pos_hess)
x_mod = x_mod.squeeze()
n = 6
x_mod_list = [x_init, x_mod]
for i in range(n):
x_mod, evals = sess.run([x_refine, es], {X: x_mod})
x_mod = x_mod.squeeze()
x_mod_list.append(x_mod)
print("Value of energies after evaluation: ", evals)
x_mod_list = x_mod_list[:]
series_xmod = np.stack(x_mod_list, axis=1)
series_header = np.tile(data[0:1, None, :, :], (1, len(x_mod_list), 1, 1))
series_total = np.concatenate([series_header, series_xmod], axis=0)
series_total_full = np.ones((*series_total.shape[:-2], 66, 66))
series_total_full[:, :, 1:-1, 1:-1] = series_total
series_total = series_total_full
series_total = series_total.transpose((0, 2, 1, 3)).reshape((-1, len(x_mod_list)*66))
im_total = rescale_im(series_total)
imsave("latent_comb.png", im_total)
def construct_label(weights, X, Y, Y_GT, model, target_vars):
# for i in range(FLAGS.num_steps):
# Y = Y + tf.random_normal(tf.shape(Y), mean=0.0, stddev=0.03)
# e = model.forward(X, weights, label=Y)
# Y_grad = tf.clip_by_value(tf.gradients(e, [Y])[0], -1, 1)
# Y = Y - 0.1 * Y_grad
# Y = tf.clip_by_value(Y, 0, 1)
# Y = Y / tf.reduce_sum(Y, axis=[1], keepdims=True)
e_bias = tf.get_variable('e_bias', shape=10, initializer=tf.initializers.zeros())
l1_norm = tf.placeholder(shape=(), dtype=tf.float32)
l2_norm = tf.placeholder(shape=(), dtype=tf.float32)
def compute_logit(X, stop_grad=False, num_steps=0):
batch_size = tf.shape(X)[0]
X = tf.reshape(X, (batch_size, 1, 32, 32, 3))
X = tf.reshape(tf.tile(X, (1, 10, 1, 1, 1)), (batch_size * 10, 32, 32, 3))
Y_new = tf.reshape(Y, (batch_size*10, 10))
X_min = X - 8 / 255.
X_max = X + 8 / 255.
for i in range(num_steps):
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
energy_noise = model.forward(X, weights, label=Y, reuse=True)
x_grad = tf.gradients(energy_noise, [X])[0]
if FLAGS.proj_norm != 0.0:
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
X = X - FLAGS.step_lr * x_grad
X = tf.maximum(tf.minimum(X, X_max), X_min)
energy = model.forward(X, weights, label=Y_new)
energy = -tf.reshape(energy, (batch_size, 10))
if stop_grad:
energy = tf.stop_gradient(energy)
return energy
# eps_norm = 30
X_min = X - l1_norm / 255.
X_max = X + l1_norm / 255.
for i in range(FLAGS.pgd):
print("Constructed loop {} of pgd attack".format(i))
X_init = X
if i == 0:
X = X + tf.to_float(tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)) / 255.
logit = compute_logit(X)
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logit)
x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255.
X = X + 2 * x_grad
if FLAGS.lnorm == -1:
X = tf.maximum(tf.minimum(X, X_max), X_min)
elif FLAGS.lnorm == 2:
X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255., axes=[1, 2, 3])
energy_stopped = compute_logit(X, stop_grad=True, num_steps=FLAGS.num_steps) + e_bias
# # Y = tf.Print(Y, [Y])
labels = tf.argmax(Y_GT, axis=1)
# max_z = tf.argmax(energy_stopped, axis=1)
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=energy_stopped)
optimizer = tf.train.AdamOptimizer(1e-2)
train_op = optimizer.minimize(loss)
accuracy = tf.contrib.metrics.accuracy(tf.argmax(energy_stopped, axis=1), labels)
target_vars['accuracy'] = accuracy
target_vars['train_op'] = train_op
target_vars['l1_norm'] = l1_norm
target_vars['l2_norm'] = l2_norm
def construct_energy(weights, X, Y, Y_GT, model, target_vars):
energy = model.forward(X, weights, label=Y_GT)
for i in range(FLAGS.num_steps):
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
energy_noise = model.forward(X, weights, label=Y_GT, reuse=True)
x_grad = tf.gradients(energy_noise, [X])[0]
if FLAGS.proj_norm != 0.0:
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
X = X - FLAGS.step_lr * x_grad
X = tf.clip_by_value(X, 0, 1)
target_vars['energy'] = energy
target_vars['energy_end'] = energy_noise
def construct_steps(weights, X, Y_GT, model, target_vars):
n = 50
scale_fac = 1.0
# if FLAGS.task == 'cycleclass':
# scale_fac = 10.0
X_mods = []
X = tf.identity(X)
mask = np.zeros((1, 32, 32, 3))
if FLAGS.task == "boxcorrupt":
mask[:, 16:, :, :] = 1
else:
mask[:, :, :, :] = 1
mask = tf.Variable(tf.convert_to_tensor(mask, dtype=tf.float32), trainable=False)
# X_targ = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
for i in range(FLAGS.num_steps):
X_old = X
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005*scale_fac) * mask
energy_noise = model.forward(X, weights, label=Y_GT, reuse=True)
x_grad = tf.gradients(energy_noise, [X])[0]
if FLAGS.proj_norm != 0.0:
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
X = X - FLAGS.step_lr * x_grad * scale_fac * mask
X = tf.clip_by_value(X, 0, 1)
if i % n == (n-1):
X_mods.append(X)
print("Constructing step {}".format(i))
target_vars['X_final'] = X
target_vars['X_mods'] = X_mods
def nearest_neighbor(dataset, sess, target_vars, logdir):
X = target_vars['X']
Y_GT = target_vars['Y_GT']
x_final = target_vars['X_final']
noise = np.random.uniform(0, 1, size=[10, 32, 32, 3])
# label = np.random.randint(0, 10, size=[10])
label = np.eye(10)
coarse = noise
for i in range(10):
x_new = sess.run([x_final], {X:coarse, Y_GT:label})[0]
coarse = x_new
x_new_dense = x_new.reshape(10, 1, 32*32*3)
dataset_dense = dataset.reshape(1, 50000, 32*32*3)
diff = np.square(x_new_dense - dataset_dense).sum(axis=2)
diff_idx = np.argsort(diff, axis=1)
panel = np.zeros((32*10, 32*6, 3))
dataset_rescale = rescale_im(dataset)
x_new_rescale = rescale_im(x_new)
for i in range(10):
panel[i*32:i*32+32, :32] = x_new_rescale[i]
for j in range(5):
panel[i*32:i*32+32, 32*j+32:32*j+64] = dataset_rescale[diff_idx[i, j]]
imsave(osp.join(logdir, "nearest.png"), panel)
def main():
if FLAGS.dataset == "cifar10":
dataset = Cifar10(train=True, noise=False)
test_dataset = Cifar10(train=False, noise=False)
else:
dataset = Imagenet(train=True)
test_dataset = Imagenet(train=False)
if FLAGS.svhn:
dataset = Svhn(train=True)
test_dataset = Svhn(train=False)
if FLAGS.task == 'latent':
dataset = DSprites()
test_dataset = dataset
dataloader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True)
hidden_dim = 128
if FLAGS.large_model:
model = ResNet32Large(num_filters=hidden_dim)
elif FLAGS.larger_model:
model = ResNet32Larger(num_filters=hidden_dim)
elif FLAGS.wider_model:
if FLAGS.dataset == 'imagenet':
model = ResNet32Wider(num_filters=196, train=False)
else:
model = ResNet32Wider(num_filters=256, train=False)
else:
model = ResNet32(num_filters=hidden_dim)
if FLAGS.task == 'latent':
model = DspritesNet()
weights = model.construct_weights('context_{}'.format(0))
total_parameters = 0
for variable in tf.trainable_variables():
# shape is an array of tf.Dimension
shape = variable.get_shape()
variable_parameters = 1
for dim in shape:
variable_parameters *= dim.value
total_parameters += variable_parameters
print("Model has a total of {} parameters".format(total_parameters))
config = tf.ConfigProto()
sess = tf.InteractiveSession()
if FLAGS.task == 'latent':
X = tf.placeholder(shape=(None, 64, 64), dtype = tf.float32)
else:
X = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
if FLAGS.dataset == "cifar10":
Y = tf.placeholder(shape=(None, 10), dtype = tf.float32)
Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32)
elif FLAGS.dataset == "imagenet":
Y = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT}
if FLAGS.task == 'label':
construct_label(weights, X, Y, Y_GT, model, target_vars)
elif FLAGS.task == 'labelfinetune':
construct_finetune_label(weights, X, Y, Y_GT, model, target_vars, )
elif FLAGS.task == 'energyeval' or FLAGS.task == 'mixenergy':
construct_energy(weights, X, Y, Y_GT, model, target_vars)
elif FLAGS.task == 'anticorrupt' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'crossclass' or FLAGS.task == 'cycleclass' or FLAGS.task == 'democlass' or FLAGS.task == 'nearestneighbor':
construct_steps(weights, X, Y_GT, model, target_vars)
elif FLAGS.task == 'latent':
construct_latent(weights, X, Y_GT, model, target_vars)
sess.run(tf.global_variables_initializer())
saver = loader = tf.train.Saver(max_to_keep=10)
savedir = osp.join('cachedir', FLAGS.exp)
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
if not osp.exists(logdir):
os.makedirs(logdir)
initialize()
if FLAGS.resume_iter != -1:
model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter))
resume_itr = FLAGS.resume_iter
if FLAGS.task == 'label' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'labelfinetune' or FLAGS.task == "energyeval" or FLAGS.task == "crossclass" or FLAGS.task == "mixenergy":
optimistic_restore(sess, model_file)
# saver.restore(sess, model_file)
else:
# optimistic_restore(sess, model_file)
saver.restore(sess, model_file)
if FLAGS.task == 'label':
if FLAGS.labelgrid:
vals = []
if FLAGS.lnorm == -1:
for i in range(31):
accuracies = label(dataloader, test_dataloader, target_vars, sess, l1val=i)
vals.append(accuracies)
elif FLAGS.lnorm == 2:
for i in range(0, 100, 5):
accuracies = label(dataloader, test_dataloader, target_vars, sess, l2val=i)
vals.append(accuracies)
np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals)
else:
label(dataloader, test_dataloader, target_vars, sess)
elif FLAGS.task == 'labelfinetune':
labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=FLAGS.lival, l2val=FLAGS.l2val)
elif FLAGS.task == 'energyeval':
energyeval(dataloader, test_dataloader, target_vars, sess)
elif FLAGS.task == 'mixenergy':
energyevalmix(dataloader, test_dataloader, target_vars, sess)
elif FLAGS.task == 'anticorrupt':
anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
elif FLAGS.task == 'boxcorrupt':
# boxcorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess)
elif FLAGS.task == 'crossclass':
crossclass(test_dataloader, weights, model, target_vars, logdir, sess)
elif FLAGS.task == 'cycleclass':
cycleclass(test_dataloader, weights, model, target_vars, logdir, sess)
elif FLAGS.task == 'democlass':
democlass(test_dataloader, weights, model, target_vars, logdir, sess)
elif FLAGS.task == 'nearestneighbor':
# print(dir(dataset))
# print(type(dataset))
nearest_neighbor(dataset.data.train_data / 255, sess, target_vars, logdir)
elif FLAGS.task == 'latent':
latent(test_dataloader, weights, model, target_vars, sess)
if __name__ == "__main__":
main()

292
fid.py Normal file
View File

@ -0,0 +1,292 @@
#!/usr/bin/env python3
''' Calculates the Frechet Inception Distance (FID) to evalulate GANs.
The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.
When run as a stand-alone program, it compares the distribution of
images that are stored as PNG/JPEG at a specified location with a
distribution given by summary statistics (in pickle format).
The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectivly.
See --help to see further details.
'''
from __future__ import absolute_import, division, print_function
import numpy as np
import os
import gzip, pickle
import tensorflow as tf
from scipy.misc import imread
from scipy import linalg
import pathlib
import urllib
import tarfile
import warnings
MODEL_DIR = '/tmp/imagenet'
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
pool3 = None
class InvalidFIDException(Exception):
pass
#-------------------------------------------------------------------------------
def get_fid_score(images, images_gt):
images = np.stack(images, 0)
images_gt = np.stack(images_gt, 0)
with tf.Session() as sess:
m1, s1 = calculate_activation_statistics(images, sess)
m2, s2 = calculate_activation_statistics(images_gt, sess)
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
print("Obtained fid value of {}".format(fid_value))
return fid_value
def create_inception_graph(pth):
"""Creates a graph from saved GraphDef file."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile( pth, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString( f.read())
_ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
#-------------------------------------------------------------------------------
# code for handling inception net derived from
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
def _get_inception_layer(sess):
"""Prepares inception net for batched usage and returns pool_3 layer. """
layername = 'FID_Inception_Net/pool_3:0'
pool3 = sess.graph.get_tensor_by_name(layername)
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
if shape._dims != []:
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
return pool3
#-------------------------------------------------------------------------------
def get_activations(images, sess, batch_size=50, verbose=False):
"""Calculates the activations of the pool_3 layer for all images.
Params:
-- images : Numpy array of dimension (n_images, hi, wi, 3). The values
must lie between 0 and 256.
-- sess : current session
-- batch_size : the images numpy array is split into batches with batch size
batch_size. A reasonable batch size depends on the disposable hardware.
-- verbose : If set to True and parameter out_step is given, the number of calculated
batches is reported.
Returns:
-- A numpy array of dimension (num images, 2048) that contains the
activations of the given tensor when feeding inception with the query tensor.
"""
# inception_layer = _get_inception_layer(sess)
d0 = images.shape[0]
if batch_size > d0:
print("warning: batch size is bigger than the data size. setting batch size to data size")
batch_size = d0
n_batches = d0//batch_size
n_used_imgs = n_batches*batch_size
pred_arr = np.empty((n_used_imgs,2048))
for i in range(n_batches):
if verbose:
print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
start = i*batch_size
end = start + batch_size
batch = images[start:end]
pred = sess.run(pool3, {'ExpandDims:0': batch})
pred_arr[start:end] = pred.reshape(batch_size,-1)
if verbose:
print(" done")
return pred_arr
#-------------------------------------------------------------------------------
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
"""Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
Stable version by Dougal J. Sutherland.
Params:
-- mu1 : Numpy array containing the activations of the pool_3 layer of the
inception net ( like returned by the function 'get_predictions')
for generated samples.
-- mu2 : The sample mean over activations of the pool_3 layer, precalcualted
on an representive data set.
-- sigma1: The covariance matrix over activations of the pool_3 layer for
generated samples.
-- sigma2: The covariance matrix over activations of the pool_3 layer,
precalcualted on an representive data set.
Returns:
-- : The Frechet Distance.
"""
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
diff = mu1 - mu2
# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
warnings.warn(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
# numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError("Imaginary component {}".format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
#-------------------------------------------------------------------------------
def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
"""Calculation of the statistics used by the FID.
Params:
-- images : Numpy array of dimension (n_images, hi, wi, 3). The values
must lie between 0 and 255.
-- sess : current session
-- batch_size : the images numpy array is split into batches with batch size
batch_size. A reasonable batch size depends on the available hardware.
-- verbose : If set to True and parameter out_step is given, the number of calculated
batches is reported.
Returns:
-- mu : The mean over samples of the activations of the pool_3 layer of
the incption model.
-- sigma : The covariance matrix of the activations of the pool_3 layer of
the incption model.
"""
act = get_activations(images, sess, batch_size, verbose)
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
# The following functions aren't needed for calculating the FID
# they're just here to make this module work as a stand-alone script
# for calculating FID scores
#-------------------------------------------------------------------------------
def check_or_download_inception(inception_path):
''' Checks if the path to the inception file is valid, or downloads
the file if it is not present. '''
INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
if inception_path is None:
inception_path = '/tmp'
inception_path = pathlib.Path(inception_path)
model_file = inception_path / 'classify_image_graph_def.pb'
if not model_file.exists():
print("Downloading Inception model")
from urllib import request
import tarfile
fn, _ = request.urlretrieve(INCEPTION_URL)
with tarfile.open(fn, mode='r') as f:
f.extract('classify_image_graph_def.pb', str(model_file.parent))
return str(model_file)
def _handle_path(path, sess):
if path.endswith('.npz'):
f = np.load(path)
m, s = f['mu'][:], f['sigma'][:]
f.close()
else:
path = pathlib.Path(path)
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
m, s = calculate_activation_statistics(x, sess)
return m, s
def calculate_fid_given_paths(paths, inception_path):
''' Calculates the FID of two paths. '''
inception_path = check_or_download_inception(inception_path)
for p in paths:
if not os.path.exists(p):
raise RuntimeError("Invalid path: %s" % p)
create_inception_graph(str(inception_path))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
m1, s1 = _handle_path(paths[0], sess)
m2, s2 = _handle_path(paths[1], sess)
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
return fid_value
def _init_inception():
global pool3
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(MODEL_DIR, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
with tf.gfile.FastGFile(os.path.join(
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
# Works with an arbitrary minibatch size.
with tf.Session() as sess:
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
if shape._dims != []:
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
if pool3 is None:
_init_inception()

129
hmc.py Normal file
View File

@ -0,0 +1,129 @@
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import flags
flags.DEFINE_bool('proposal_debug', False, 'Print hmc acceptance raes')
FLAGS = flags.FLAGS
def kinetic_energy(velocity):
"""Kinetic energy of the current velocity (assuming a standard Gaussian)
(x dot x) / 2
Parameters
----------
velocity : tf.Variable
Vector of current velocity
Returns
-------
kinetic_energy : float
"""
return 0.5 * tf.square(velocity)
def hamiltonian(position, velocity, energy_function):
"""Computes the Hamiltonian of the current position, velocity pair
H = U(x) + K(v)
U is the potential energy and is = -log_posterior(x)
Parameters
----------
position : tf.Variable
Position or state vector x (sample from the target distribution)
velocity : tf.Variable
Auxiliary velocity variable
energy_function
Function from state to position to 'energy'
= -log_posterior
Returns
-------
hamitonian : float
"""
batch_size = tf.shape(velocity)[0]
kinetic_energy_flat = tf.reshape(kinetic_energy(velocity), (batch_size, -1))
return tf.squeeze(energy_function(position)) + tf.reduce_sum(kinetic_energy_flat, axis=[1])
def leapfrog_step(x0,
v0,
neg_log_posterior,
step_size,
num_steps):
# Start by updating the velocity a half-step
v = v0 - 0.5 * step_size * tf.gradients(neg_log_posterior(x0), x0)[0]
# Initalize x to be the first step
x = x0 + step_size * v
for i in range(num_steps):
# Compute gradient of the log-posterior with respect to x
gradient = tf.gradients(neg_log_posterior(x), x)[0]
# Update velocity
v = v - step_size * gradient
# x_clip = tf.clip_by_value(x, 0.0, 1.0)
# x = x_clip
# v_mask = 1 - 2 * tf.abs(tf.sign(x - x_clip))
# v = v * v_mask
# Update x
x = x + step_size * v
# x = tf.clip_by_value(x, -0.01, 1.01)
# x = tf.Print(x, [tf.reduce_min(x), tf.reduce_max(x), tf.reduce_mean(x)])
# Do a final update of the velocity for a half step
v = v - 0.5 * step_size * tf.gradients(neg_log_posterior(x), x)[0]
# return new proposal state
return x, v
def hmc(initial_x,
step_size,
num_steps,
neg_log_posterior):
"""Summary
Parameters
----------
initial_x : tf.Variable
Initial sample x ~ p
step_size : float
Step-size in Hamiltonian simulation
num_steps : int
Number of steps to take in Hamiltonian simulation
neg_log_posterior : str
Negative log posterior (unnormalized) for the target distribution
Returns
-------
sample :
Sample ~ target distribution
"""
v0 = tf.random_normal(tf.shape(initial_x))
x, v = leapfrog_step(initial_x,
v0,
step_size=step_size,
num_steps=num_steps,
neg_log_posterior=neg_log_posterior)
orig = hamiltonian(initial_x, v0, neg_log_posterior)
current = hamiltonian(x, v, neg_log_posterior)
prob_accept = tf.exp(orig - current)
if FLAGS.proposal_debug:
prob_accept = tf.Print(prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))])
uniform = tf.random_uniform(tf.shape(prob_accept))
keep_mask = (prob_accept > uniform)
# print(keep_mask.get_shape())
x_new = tf.where(keep_mask, x, initial_x)
return x_new

73
imagenet_demo.py Normal file
View File

@ -0,0 +1,73 @@
from models import ResNet128
import numpy as np
import os.path as osp
from tensorflow.python.platform import flags
import tensorflow as tf
import imageio
from utils import optimistic_restore
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_integer('num_steps', 200, 'num of steps for conditional imagenet sampling')
flags.DEFINE_float('step_lr', 180., 'step size for Langevin dynamics')
flags.DEFINE_integer('batch_size', 16, 'number of steps to run')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization in weights in a model')
flags.DEFINE_bool('cclass', True, 'conditional models')
flags.DEFINE_bool('use_attention', False, 'using attention')
FLAGS = flags.FLAGS
def rescale_im(im):
return np.clip(im * 256, 0, 255).astype(np.uint8)
if __name__ == "__main__":
model = ResNet128(num_filters=64)
X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
sess = tf.InteractiveSession()
weights = model.construct_weights("context_0")
x_mod = X_NOISE
x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
mean=0.0,
stddev=0.005)
energy_noise = energy_start = model.forward(x_mod, weights, label=LABEL,
reuse=True, stop_at_grad=False, stop_batch=True)
x_grad = tf.gradients(energy_noise, [x_mod])[0]
energy_noise_old = energy_noise
lr = FLAGS.step_lr
x_last = x_mod - (lr) * x_grad
x_mod = x_last
x_mod = tf.clip_by_value(x_mod, 0, 1)
x_output = x_mod
sess.run(tf.global_variables_initializer())
saver = loader = tf.train.Saver()
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
saver.restore(sess, model_file)
lx = np.random.permutation(1000)[:16]
ims = []
# What to initialize sampling with.
x_mod = np.random.uniform(0, 1, size=(FLAGS.batch_size, 128, 128, 3))
labels = np.eye(1000)[lx]
for i in range(FLAGS.num_steps):
e, x_mod = sess.run([energy_noise, x_output], {X_NOISE:x_mod, LABEL:labels})
ims.append(rescale_im(x_mod).reshape((4, 4, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((512, 512, 3)))
imageio.mimwrite('sample.gif', ims)

337
imagenet_preprocessing.py Normal file
View File

@ -0,0 +1,337 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image pre-processing utilities.
"""
import tensorflow as tf
IMAGE_DEPTH = 3 # color images
import tensorflow as tf
# _R_MEAN = 123.68
# _G_MEAN = 116.78
# _B_MEAN = 103.94
# _CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN]
_CHANNEL_MEANS = [0.0, 0.0, 0.0]
# The lower bound for the smallest side of the image for aspect-preserving
# resizing. For example, if an image is 500 x 1000, it will be resized to
# _RESIZE_MIN x (_RESIZE_MIN * 2).
_RESIZE_MIN = 128
def _decode_crop_and_flip(image_buffer, bbox, num_channels):
"""Crops the given image to a random part of the image, and randomly flips.
We use the fused decode_and_crop op, which performs better than the two ops
used separately in series, but note that this requires that the image be
passed in as an un-decoded string Tensor.
Args:
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
num_channels: Integer depth of the image buffer for decoding.
Returns:
3-D tensor with cropped image.
"""
# A large fraction of image datasets contain a human-annotated bounding box
# delineating the region of the image containing the object of interest. We
# choose to create a new bounding box for the object which is a randomly
# distorted version of the human-annotated bounding box that obeys an
# allowed range of aspect ratios, sizes and overlap with the human-annotated
# bounding box. If no box is supplied, then we assume the bounding box is
# the entire image.
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
tf.image.extract_jpeg_shape(image_buffer),
bounding_boxes=bbox,
min_object_covered=0.1,
aspect_ratio_range=[0.75, 1.33],
area_range=[0.05, 1.0],
max_attempts=100,
use_image_if_no_bounding_boxes=True)
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
# Reassemble the bounding box in the format the crop op requires.
offset_y, offset_x, _ = tf.unstack(bbox_begin)
target_height, target_width, _ = tf.unstack(bbox_size)
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
# Use the fused decode and crop op here, which is faster than each in series.
cropped = tf.image.decode_and_crop_jpeg(
image_buffer, crop_window, channels=num_channels)
# Flip to add a little more random distortion in.
cropped = tf.image.random_flip_left_right(cropped)
return cropped
def _central_crop(image, crop_height, crop_width):
"""Performs central crops of the given image list.
Args:
image: a 3-D image tensor
crop_height: the height of the image following the crop.
crop_width: the width of the image following the crop.
Returns:
3-D tensor with cropped image.
"""
shape = tf.shape(input=image)
height, width = shape[0], shape[1]
amount_to_be_cropped_h = (height - crop_height)
crop_top = amount_to_be_cropped_h // 2
amount_to_be_cropped_w = (width - crop_width)
crop_left = amount_to_be_cropped_w // 2
return tf.slice(
image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
def _mean_image_subtraction(image, means, num_channels):
"""Subtracts the given means from each image channel.
For example:
means = [123.68, 116.779, 103.939]
image = _mean_image_subtraction(image, means)
Note that the rank of `image` must be known.
Args:
image: a tensor of size [height, width, C].
means: a C-vector of values to subtract from each channel.
num_channels: number of color channels in the image that will be distorted.
Returns:
the centered image.
Raises:
ValueError: If the rank of `image` is unknown, if `image` has a rank other
than three or if the number of channels in `image` doesn't match the
number of values in `means`.
"""
if image.get_shape().ndims != 3:
raise ValueError('Input must be of size [height, width, C>0]')
if len(means) != num_channels:
raise ValueError('len(means) must match the number of channels')
# We have a 1-D tensor of means; convert to 3-D.
means = tf.expand_dims(tf.expand_dims(means, 0), 0)
return image - means
def _smallest_size_at_least(height, width, resize_min):
"""Computes new shape with the smallest side equal to `smallest_side`.
Computes new shape with the smallest side equal to `smallest_side` while
preserving the original aspect ratio.
Args:
height: an int32 scalar tensor indicating the current height.
width: an int32 scalar tensor indicating the current width.
resize_min: A python integer or scalar `Tensor` indicating the size of
the smallest side after resize.
Returns:
new_height: an int32 scalar tensor indicating the new height.
new_width: an int32 scalar tensor indicating the new width.
"""
resize_min = tf.cast(resize_min, tf.float32)
# Convert to floats to make subsequent calculations go smoothly.
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
smaller_dim = tf.minimum(height, width)
scale_ratio = resize_min / smaller_dim
# Convert back to ints to make heights and widths that TF ops will accept.
new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
return new_height, new_width
def _aspect_preserving_resize(image, resize_min):
"""Resize images preserving the original aspect ratio.
Args:
image: A 3-D image `Tensor`.
resize_min: A python integer or scalar `Tensor` indicating the size of
the smallest side after resize.
Returns:
resized_image: A 3-D tensor containing the resized image.
"""
shape = tf.shape(input=image)
height, width = shape[0], shape[1]
new_height, new_width = _smallest_size_at_least(height, width, resize_min)
return _resize_image(image, new_height, new_width)
def _resize_image(image, height, width):
"""Simple wrapper around tf.resize_images.
This is primarily to make sure we use the same `ResizeMethod` and other
details each time.
Args:
image: A 3-D image `Tensor`.
height: The target height for the resized image.
width: The target width for the resized image.
Returns:
resized_image: A 3-D tensor containing the resized image. The first two
dimensions have the shape [height, width].
"""
return tf.image.resize_images(
image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
align_corners=False)
def preprocess_image(image_buffer, bbox, output_height, output_width,
num_channels, is_training=False):
"""Preprocesses the given image.
Preprocessing includes decoding, cropping, and resizing for both training
and eval images. Training preprocessing, however, introduces some random
distortion of the image to improve accuracy.
Args:
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
output_height: The height of the image after preprocessing.
output_width: The width of the image after preprocessing.
num_channels: Integer depth of the image buffer for decoding.
is_training: `True` if we're preprocessing the image for training and
`False` otherwise.
Returns:
A preprocessed image.
"""
if is_training:
# For training, we want to randomize some of the distortions.
image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
image = _resize_image(image, output_height, output_width)
else:
# For validation, we want to decode, resize, then just crop the middle.
image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
image = _aspect_preserving_resize(image, _RESIZE_MIN)
print(image)
image = _central_crop(image, output_height, output_width)
image.set_shape([output_height, output_width, num_channels])
return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
def parse_example_proto(example_serialized):
"""Parses an Example proto containing a training example of an image.
The output of the build_image_data.py image preprocessing script is a dataset
containing serialized Example protocol buffers. Each Example proto contains
the following fields:
image/height: 462
image/width: 581
image/colorspace: 'RGB'
image/channels: 3
image/class/label: 615
image/class/synset: 'n03623198'
image/class/text: 'knee pad'
image/object/bbox/xmin: 0.1
image/object/bbox/xmax: 0.9
image/object/bbox/ymin: 0.2
image/object/bbox/ymax: 0.6
image/object/bbox/label: 615
image/format: 'JPEG'
image/filename: 'ILSVRC2012_val_00041207.JPEG'
image/encoded: <JPEG encoded string>
Args:
example_serialized: scalar Tensor tf.string containing a serialized
Example protocol buffer.
Returns:
image_buffer: Tensor tf.string containing the contents of a JPEG file.
label: Tensor tf.int32 containing the label.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
text: Tensor tf.string containing the human-readable label.
"""
# Dense features in Example proto.
feature_map = {
'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
default_value=''),
'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
default_value=-1),
'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
default_value=''),
}
sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
# Sparse features in Example proto.
feature_map.update(
{k: sparse_float32 for k in ['image/object/bbox/xmin',
'image/object/bbox/ymin',
'image/object/bbox/xmax',
'image/object/bbox/ymax']})
features = tf.parse_single_example(example_serialized, feature_map)
label = tf.cast(features['image/class/label'], dtype=tf.int32)
xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
# Note that we impose an ordering of (y, x) just to make life difficult.
bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
# Force the variable number of bounding boxes into the shape
# [1, num_boxes, coords].
bbox = tf.expand_dims(bbox, 0)
bbox = tf.transpose(bbox, [0, 2, 1])
return features['image/encoded'], label, bbox, features['image/class/text']
class ImagenetPreprocessor:
def __init__(self, image_size, dtype, train):
self.image_size = image_size
self.dtype = dtype
self.train = train
def preprocess(self, image_buffer, bbox):
# pylint: disable=g-import-not-at-top
image = preprocess_image(image_buffer, bbox, self.image_size, self.image_size, IMAGE_DEPTH, is_training=self.train)
return tf.cast(image, self.dtype)
def parse_and_preprocess(self, value):
image_buffer, label_index, bbox, _ = parse_example_proto(value)
image = self.preprocess(image_buffer, bbox)
image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
return label_index, image

105
inception.py Normal file
View File

@ -0,0 +1,105 @@
# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import sys
import tarfile
import numpy as np
from six.moves import urllib
import tensorflow as tf
import glob
import scipy.misc
import math
import sys
import horovod.tensorflow as hvd
MODEL_DIR = '/tmp/imagenet'
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
softmax = None
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())
sess = tf.Session(config=config)
# Call this function with list of images. Each of elements should be a
# numpy array with values ranging from 0 to 255.
def get_inception_score(images, splits=10):
# For convenience
if len(images[0].shape) != 3:
return 0, 0
# Bypassing all the assertions so that we don't end prematuraly'
# assert(type(images) == list)
# assert(type(images[0]) == np.ndarray)
# assert(len(images[0].shape) == 3)
# assert(np.max(images[0]) > 10)
# assert(np.min(images[0]) >= 0.0)
inps = []
for img in images:
img = img.astype(np.float32)
inps.append(np.expand_dims(img, 0))
bs = 1
preds = []
n_batches = int(math.ceil(float(len(inps)) / float(bs)))
for i in range(n_batches):
sys.stdout.write(".")
sys.stdout.flush()
inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
inp = np.concatenate(inp, 0)
pred = sess.run(softmax, {'ExpandDims:0': inp})
preds.append(pred)
preds = np.concatenate(preds, 0)
scores = []
for i in range(splits):
part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
kl = np.mean(np.sum(kl, 1))
scores.append(np.exp(kl))
return np.mean(scores), np.std(scores)
# This function is called automatically.
def _init_inception():
global softmax
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(MODEL_DIR, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
with tf.gfile.FastGFile(os.path.join(
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
# Works with an arbitrary minibatch size.
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.set_shape(tf.TensorShape(new_shape))
w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
softmax = tf.nn.softmax(logits)
if softmax is None:
_init_inception()

622
models.py Normal file
View File

@ -0,0 +1,622 @@
import tensorflow as tf
from tensorflow.python.platform import flags
import numpy as np
from utils import conv_block, get_weight, attention, conv_cond_concat, init_conv_weight, init_attention_weight, init_res_weight, smart_res_block, smart_res_block_optim, init_convt_weight
from utils import init_fc_weight, smart_conv_block, smart_fc_block, smart_atten_block, groupsort, smart_convt_block, swish
flags.DEFINE_bool('swish_act', False, 'use the swish activation for dsprites')
FLAGS = flags.FLAGS
class MnistNet(object):
def __init__(self, num_channels=1, num_filters=64):
self.channels = num_channels
self.dim_hidden = num_filters
self.datasource = FLAGS.datasource
if FLAGS.cclass:
self.label_size = 10
else:
self.label_size = 0
def construct_weights(self, scope=''):
weights = {}
dtype = tf.float32
conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
classes = 1
with tf.variable_scope(scope):
init_conv_weight(weights, 'c1_pre', 3, 1, 64)
init_conv_weight(weights, 'c1', 4, 64, self.dim_hidden, classes=classes)
init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_fc_weight(weights, 'fc_dense', 4*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True)
init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False)
if FLAGS.cclass:
self.label_size = 10
else:
self.label_size = 0
return weights
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, **kwargs):
channels = self.channels
weights = weights.copy()
inp = tf.reshape(inp, (tf.shape(inp)[0], 28, 28, 1))
if FLAGS.swish_act:
act = swish
else:
act = tf.nn.leaky_relu
if stop_grad:
for k, v in weights.items():
if type(v) == dict:
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
v[k_sub] = tf.stop_gradient(v_sub)
else:
weights[k] = tf.stop_gradient(v)
if FLAGS.cclass:
label_d = tf.reshape(label, shape=(tf.shape(label)[0], 1, 1, self.label_size))
inp = conv_cond_concat(inp, label_d)
h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act)
h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act)
h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=False, extra_bias=False, activation=act)
h5 = tf.reshape(h4, [-1, np.prod([int(dim) for dim in h4.get_shape()[1:]])])
h6 = act(smart_fc_block(h5, weights, reuse, 'fc_dense'))
hidden6 = smart_fc_block(h6, weights, reuse, 'fc5')
return hidden6
class DspritesNet(object):
def __init__(self, num_channels=1, num_filters=64, cond_size=False, cond_shape=False, cond_pos=False,
cond_rot=False, label_size=1):
self.channels = num_channels
self.dim_hidden = num_filters
self.img_size = 64
self.label_size = label_size
if FLAGS.cclass:
self.label_size = 3
try:
if FLAGS.dshape_only:
self.label_size = 3
if FLAGS.dpos_only:
self.label_size = 2
if FLAGS.dsize_only:
self.label_size = 1
if FLAGS.drot_only:
self.label_size = 2
except:
pass
if cond_size:
self.label_size = 1
if cond_shape:
self.label_size = 3
if cond_pos:
self.label_size = 2
if cond_rot:
self.label_size = 2
self.cond_size = cond_size
self.cond_shape = cond_shape
self.cond_pos = cond_pos
def construct_weights(self, scope=''):
weights = {}
dtype = tf.float32
conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
k = 5
classes = self.label_size
with tf.variable_scope(scope):
init_conv_weight(weights, 'c1_pre', 3, 1, 32)
init_conv_weight(weights, 'c1', 4, 32, self.dim_hidden, classes=classes)
init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_conv_weight(weights, 'c4', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_fc_weight(weights, 'fc_dense', 2*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True)
init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False)
return weights
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False, return_logit=False):
channels = self.channels
batch_size = tf.shape(inp)[0]
inp = tf.reshape(inp, (batch_size, 64, 64, 1))
if FLAGS.swish_act:
act = swish
else:
act = tf.nn.leaky_relu
if not FLAGS.cclass:
label = None
weights = weights.copy()
if stop_grad:
for k, v in weights.items():
if type(v) == dict:
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
v[k_sub] = tf.stop_gradient(v_sub)
else:
weights[k] = tf.stop_gradient(v)
h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=True, extra_bias=True, activation=act)
h5 = smart_conv_block(h4, weights, reuse, 'c4', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
hidden6 = tf.reshape(h5, (tf.shape(h5)[0], -1))
hidden7 = act(smart_fc_block(hidden6, weights, reuse, 'fc_dense'))
energy = smart_fc_block(hidden7, weights, reuse, 'fc5')
if return_logit:
return hidden7
else:
return energy
class ResNet32(object):
def __init__(self, num_channels=3, num_filters=128):
self.channels = num_channels
self.dim_hidden = num_filters
self.groupsort = groupsort()
def construct_weights(self, scope=''):
weights = {}
dtype = tf.float32
if FLAGS.cclass:
classes = 10
else:
classes = 1
with tf.variable_scope(scope):
# First block
init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_2', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_3', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden)
init_fc_weight(weights, 'fc5', 2*self.dim_hidden , 1, spec_norm=False)
init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True)
return weights
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
weights = weights.copy()
batch = tf.shape(inp)[0]
act = tf.nn.leaky_relu
if not FLAGS.cclass:
label = None
if stop_grad:
for k, v in weights.items():
if type(v) == dict:
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
v[k_sub] = tf.stop_gradient(v_sub)
else:
weights[k] = tf.stop_gradient(v)
# Make sure gradients are modified a bit
inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, act=act)
hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, act=act)
hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, label=label, act=act)
if FLAGS.use_attention:
hidden4 = smart_atten_block(hidden3, weights, reuse, 'atten', stop_at_grad=stop_at_grad, label=label)
else:
hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, act=act)
hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', stop_batch=stop_batch, adaptive=False, label=label, act=act)
compact = hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
hidden6 = tf.nn.relu(hidden6)
hidden5 = tf.reduce_sum(hidden6, [1, 2])
hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
energy = hidden6
return energy
class ResNet32Large(object):
def __init__(self, num_channels=3, num_filters=128, train=False):
self.channels = num_channels
self.dim_hidden = num_filters
self.dropout = train
self.train = train
def construct_weights(self, scope=''):
weights = {}
dtype = tf.float32
if FLAGS.cclass:
classes = 10
else:
classes = 1
with tf.variable_scope(scope):
# First block
init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden, trainable_gamma=True)
return weights
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
weights = weights.copy()
batch = tf.shape(inp)[0]
if not FLAGS.cclass:
label = None
if stop_grad:
for k, v in weights.items():
if type(v) == dict:
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
v[k_sub] = tf.stop_gradient(v_sub)
else:
weights[k] = tf.stop_gradient(v)
# Make sure gradients are modified a bit
inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
dropout = self.dropout
train = self.train
hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, dropout=dropout, train=train)
hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train)
hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train)
hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train)
if FLAGS.use_attention:
hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
else:
hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train)
hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
compact = hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
if FLAGS.cclass:
hidden6 = tf.nn.leaky_relu(hidden9)
else:
hidden6 = tf.nn.relu(hidden9)
hidden5 = tf.reduce_sum(hidden6, [1, 2])
hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
energy = hidden6
return energy
class ResNet32Wider(object):
def __init__(self, num_channels=3, num_filters=128, train=False):
self.channels = num_channels
self.dim_hidden = num_filters
self.dropout = train
self.train = train
def construct_weights(self, scope=''):
weights = {}
dtype = tf.float32
if FLAGS.cclass and FLAGS.dataset == "cifar10":
classes = 10
elif FLAGS.cclass and FLAGS.dataset == "imagenet":
classes = 1000
else:
classes = 1
with tf.variable_scope(scope):
# First block
init_conv_weight(weights, 'c1_pre', 3, self.channels, 128)
init_res_weight(weights, 'res_optim', 3, 128, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True)
return weights
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
weights = weights.copy()
batch = tf.shape(inp)[0]
if not FLAGS.cclass:
label = None
if stop_grad:
for k, v in weights.items():
if type(v) == dict:
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
v[k_sub] = tf.stop_gradient(v_sub)
else:
weights[k] = tf.stop_gradient(v)
if FLAGS.swish_act:
act = swish
else:
act = tf.nn.leaky_relu
# Make sure gradients are modified a bit
inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
dropout = self.dropout
train = self.train
hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=True, label=label, dropout=dropout, train=train)
if FLAGS.use_attention:
hidden2 = smart_atten_block(hidden1, weights, reuse, 'atten', train=train, dropout=dropout, stop_at_grad=stop_at_grad)
else:
hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act)
hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act)
hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
if FLAGS.swish_act:
hidden6 = act(hidden9)
else:
hidden6 = tf.nn.relu(hidden9)
hidden5 = tf.reduce_sum(hidden6, [1, 2])
hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
energy = hidden6
return energy
class ResNet32Larger(object):
def __init__(self, num_channels=3, num_filters=128):
self.channels = num_channels
self.dim_hidden = num_filters
def construct_weights(self, scope=''):
weights = {}
dtype = tf.float32
if FLAGS.cclass:
classes = 10
else:
classes = 1
with tf.variable_scope(scope):
# First block
init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_2a', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_2b', 3, self.dim_hidden, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_5a', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_5b', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_8a', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_8b', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden)
init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True)
return weights
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
weights = weights.copy()
batch = tf.shape(inp)[0]
if not FLAGS.cclass:
label = None
if stop_grad:
for k, v in weights.items():
if type(v) == dict:
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
v[k_sub] = tf.stop_gradient(v_sub)
else:
weights[k] = tf.stop_gradient(v)
# Make sure gradients are modified a bit
inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label)
hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2a', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2b', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label)
if FLAGS.use_attention:
hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
else:
hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label)
hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
compact = hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
if FLAGS.cclass:
hidden6 = tf.nn.leaky_relu(hidden9)
else:
hidden6 = tf.nn.relu(hidden9)
hidden5 = tf.reduce_sum(hidden6, [1, 2])
hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
energy = hidden6
return energy
class ResNet128(object):
"""Construct the convolutional network specified in MAML"""
def __init__(self, num_channels=3, num_filters=64, train=False):
self.channels = num_channels
self.dim_hidden = num_filters
self.dropout = train
self.train = train
def construct_weights(self, scope=''):
weights = {}
dtype = tf.float32
classes = 1000
with tf.variable_scope(scope):
# First block
init_conv_weight(weights, 'c1_pre', 3, self.channels, 64)
init_res_weight(weights, 'res_optim', 3, 64, self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 8*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_9', 3, 8*self.dim_hidden, 8*self.dim_hidden, classes=classes)
init_res_weight(weights, 'res_10', 3, 8*self.dim_hidden, 8*self.dim_hidden, classes=classes)
init_fc_weight(weights, 'fc5', 8*self.dim_hidden , 1, spec_norm=False)
init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden / 2., trainable_gamma=True)
return weights
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
weights = weights.copy()
batch = tf.shape(inp)[0]
if not FLAGS.cclass:
label = None
if stop_grad:
for k, v in weights.items():
if type(v) == dict:
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
v[k_sub] = tf.stop_gradient(v_sub)
else:
weights[k] = tf.stop_gradient(v)
if FLAGS.swish_act:
act = swish
else:
act = tf.nn.leaky_relu
dropout = self.dropout
train = self.train
# Make sure gradients are modified a bit
inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', label=label, dropout=dropout, train=train, downsample=True, adaptive=False)
if FLAGS.use_attention:
hidden1 = smart_atten_block(hidden1, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
hidden2 = smart_res_block(hidden1, weights, reuse, 'res_3', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act)
hidden3 = smart_res_block(hidden2, weights, reuse, 'res_5', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act)
hidden4 = smart_res_block(hidden3, weights, reuse, 'res_7', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=True)
hidden5 = smart_res_block(hidden4, weights, reuse, 'res_9', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=False)
hidden6 = smart_res_block(hidden5, weights, reuse, 'res_10', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=False, adaptive=False)
if FLAGS.swish_act:
hidden6 = act(hidden6)
else:
hidden6 = tf.nn.relu(hidden6)
hidden5 = tf.reduce_sum(hidden6, [1, 2])
hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
energy = hidden6
return energy

17
requirements.txt Normal file
View File

@ -0,0 +1,17 @@
scipy==1.1.0
horovod==0.16.0
torch==1.5.0
torchvision==0.6.0
six==1.11.0
imageio==2.8.0
tqdm==4.46.0
matplotlib==3.2.1
mpi4py==3.0.3
numpy==1.18.4
Pillow==5.4.1
baselines==0.1.5
scikit-image==0.14.2
scikit_learn
tensorflow==1.13.1
cloudpickle==1.3.0
Cython==0.29.17

333
test_inception.py Normal file
View File

@ -0,0 +1,333 @@
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import flags
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
import os.path as osp
import os
from utils import optimistic_restore, remap_restore, optimistic_remap_restore
from tqdm import tqdm
import random
from scipy.misc import imsave
from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, TFImagenetLoader
from torch.utils.data import DataLoader
from baselines.common.tf_util import initialize
import horovod.tensorflow as hvd
hvd.init()
from inception import get_inception_score
from fid import get_fid_score
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_bool('cclass', False, 'whether to condition on class')
# Architecture settings
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_float('step_lr', 10.0, 'Size of steps for gradient descent')
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
flags.DEFINE_float('proj_norm', 0.05, 'Maximum change of input images')
flags.DEFINE_integer('batch_size', 512, 'batch size')
flags.DEFINE_integer('resume_iter', -1, 'resume iteration')
flags.DEFINE_integer('ensemble', 10, 'number of ensembles')
flags.DEFINE_integer('im_number', 50000, 'number of ensembles')
flags.DEFINE_integer('repeat_scale', 100, 'number of repeat iterations')
flags.DEFINE_float('noise_scale', 0.005, 'amount of noise to output')
flags.DEFINE_integer('idx', 0, 'save index')
flags.DEFINE_integer('nomix', 10, 'number of intervals to stop mixing')
flags.DEFINE_bool('scaled', True, 'whether to scale noise added')
flags.DEFINE_bool('large_model', False, 'whether to use a small or large model')
flags.DEFINE_bool('larger_model', False, 'Whether to use a large model')
flags.DEFINE_bool('wider_model', False, 'Whether to use a large model')
flags.DEFINE_bool('single', False, 'single ')
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or imagenet or imagenetfull')
FLAGS = flags.FLAGS
class InceptionReplayBuffer(object):
def __init__(self, size):
"""Create Replay buffer.
Parameters
----------
size: int
Max number of transitions to store in the buffer. When the buffer
overflows the old memories are dropped.
"""
self._storage = []
self._label_storage = []
self._maxsize = size
self._next_idx = 0
def __len__(self):
return len(self._storage)
def add(self, ims, labels):
batch_size = ims.shape[0]
if self._next_idx >= len(self._storage):
self._storage.extend(list(ims))
self._label_storage.extend(list(labels))
else:
if batch_size + self._next_idx < self._maxsize:
self._storage[self._next_idx:self._next_idx+batch_size] = list(ims)
self._label_storage[self._next_idx:self._next_idx+batch_size] = list(labels)
else:
split_idx = self._maxsize - self._next_idx
self._storage[self._next_idx:] = list(ims)[:split_idx]
self._storage[:batch_size-split_idx] = list(ims)[split_idx:]
self._label_storage[self._next_idx:] = list(labels)[:split_idx]
self._label_storage[:batch_size-split_idx] = list(labels)[split_idx:]
self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize
def _encode_sample(self, idxes):
ims = []
labels = []
for i in idxes:
ims.append(self._storage[i])
labels.append(self._label_storage[i])
return np.array(ims), np.array(labels)
def sample(self, batch_size):
"""Sample a batch of experiences.
Parameters
----------
batch_size: int
How many transitions to sample.
Returns
-------
obs_batch: np.array
batch of observations
act_batch: np.array
batch of actions executed given obs_batch
rew_batch: np.array
rewards received as results of executing act_batch
next_obs_batch: np.array
next set of observations seen after executing act_batch
done_mask: np.array
done_mask[i] = 1 if executing act_batch[i] resulted in
the end of an episode and 0 otherwise.
"""
idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
return self._encode_sample(idxes), idxes
def set_elms(self, idxes, data, labels):
for i, ix in enumerate(idxes):
self._storage[ix] = data[i]
self._label_storage[ix] = labels[i]
def rescale_im(im):
return np.clip(im * 256, 0, 255).astype(np.uint8)
def compute_inception(sess, target_vars):
X_START = target_vars['X_START']
Y_GT = target_vars['Y_GT']
X_finals = target_vars['X_finals']
NOISE_SCALE = target_vars['NOISE_SCALE']
energy_noise = target_vars['energy_noise']
size = FLAGS.im_number
num_steps = size // 1000
images = []
test_ims = []
if FLAGS.dataset == "cifar10":
test_dataset = Cifar10(full=True, noise=False)
elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
test_dataset = Imagenet(train=False)
if FLAGS.dataset != "imagenetfull":
test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False)
else:
test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)
for data_corrupt, data, label_gt in tqdm(test_dataloader):
data = data.numpy()
test_ims.extend(list(rescale_im(data)))
if FLAGS.dataset == "imagenetfull" and len(test_ims) > 60000:
test_ims = test_ims[:60000]
break
# n = min(len(images), len(test_ims))
print(len(test_ims))
# fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
# print("Base FID of score {}".format(fid))
if FLAGS.dataset == "cifar10":
classes = 10
else:
classes = 1000
if FLAGS.dataset == "imagenetfull":
n = 128
else:
n = 32
for j in range(num_steps):
itr = int(1000 / 500 * FLAGS.repeat_scale)
data_buffer = InceptionReplayBuffer(1000)
curr_index = 0
identity = np.eye(classes)
for i in tqdm(range(itr)):
model_index = curr_index % len(X_finals)
x_final = X_finals[model_index]
noise_scale = [1]
if len(data_buffer) < 1000:
x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
label = np.random.randint(0, classes, (FLAGS.batch_size))
label = identity[label]
x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0]
data_buffer.add(x_new, label)
else:
(x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99)
label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9)
label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
label_corrupt = identity[label_corrupt]
x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
if i < itr - FLAGS.nomix:
x_init[keep_mask] = x_init_corrupt[keep_mask]
label[label_keep_mask] = label_corrupt[label_keep_mask]
# else:
# noise_scale = [0.7]
x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})
data_buffer.set_elms(idx, x_new, label)
if FLAGS.im_number != 50000:
print(np.mean(e_noise), np.std(e_noise))
curr_index += 1
ims = np.array(data_buffer._storage[:1000])
ims = rescale_im(ims)
images.extend(list(ims))
saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx))
ims = ims[:100]
if FLAGS.dataset != "imagenetfull":
im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3))
else:
im_panel = ims.reshape((10, 10, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((1280, 1280, 3))
imsave(saveim, im_panel)
print("Saved image!!!!")
splits = max(1, len(images) // 5000)
score, std = get_inception_score(images, splits=splits)
print("Inception score of {} with std of {}".format(score, std))
# FID score
# n = min(len(images), len(test_ims))
fid = get_fid_score(images, test_ims)
print("FID of score {}".format(fid))
def main(model_list):
if FLAGS.dataset == "imagenetfull":
model = ResNet128(num_filters=64)
elif FLAGS.large_model:
model = ResNet32Large(num_filters=128)
elif FLAGS.larger_model:
model = ResNet32Larger(num_filters=hidden_dim)
elif FLAGS.wider_model:
model = ResNet32Wider(num_filters=256, train=False)
else:
model = ResNet32(num_filters=128)
# config = tf.ConfigProto()
sess = tf.InteractiveSession()
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
weights = []
for i, model_num in enumerate(model_list):
weight = model.construct_weights('context_{}'.format(i))
initialize()
save_file = osp.join(logdir, 'model_{}'.format(model_num))
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(i))
v_map = {(v.name.replace('context_{}'.format(i), 'context_0')[:-2]): v for v in v_list}
saver = tf.train.Saver(v_map)
try:
saver.restore(sess, save_file)
except:
optimistic_remap_restore(sess, save_file, i)
weights.append(weight)
if FLAGS.dataset == "imagenetfull":
X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype = tf.float32)
else:
X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
if FLAGS.dataset == "cifar10":
Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32)
else:
Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32)
X_finals = []
# Seperate loops
for weight in weights:
X = X_START
steps = tf.constant(0)
c = lambda i, x: tf.less(i, FLAGS.num_steps)
def langevin_step(counter, X):
scale_rate = 1
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE)
energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
x_grad = tf.gradients(energy_noise, [X])[0]
if FLAGS.proj_norm != 0.0:
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
X = X - FLAGS.step_lr * x_grad * scale_rate
X = tf.clip_by_value(X, 0, 1)
counter = counter + 1
return counter, X
steps, X = tf.while_loop(c, langevin_step, (steps, X))
energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
X_final = X
X_finals.append(X_final)
target_vars = {}
target_vars['X_START'] = X_START
target_vars['Y_GT'] = Y_GT
target_vars['X_finals'] = X_finals
target_vars['NOISE_SCALE'] = NOISE_SCALE
target_vars['energy_noise'] = energy_noise
compute_inception(sess, target_vars)
if __name__ == "__main__":
# model_list = [117000, 116700]
model_list = [FLAGS.resume_iter - 300*i for i in range(FLAGS.ensemble)]
main(model_list)

941
train.py Normal file
View File

@ -0,0 +1,941 @@
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import flags
from data import Imagenet, Cifar10, DSprites, Mnist, TFImagenetLoader
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, MnistNet, ResNet128
import os.path as osp
import os
from baselines.logger import TensorBoardOutputFormat
from utils import average_gradients, ReplayBuffer, optimistic_restore
from tqdm import tqdm
import random
from torch.utils.data import DataLoader
import time as time
from io import StringIO
from tensorflow.core.util import event_pb2
import torch
import numpy as np
from custom_adam import AdamOptimizer
from scipy.misc import imsave
import matplotlib.pyplot as plt
from hmc import hmc
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
import horovod.tensorflow as hvd
hvd.init()
from inception import get_inception_score
torch.manual_seed(hvd.rank())
np.random.seed(hvd.rank())
tf.set_random_seed(hvd.rank())
FLAGS = flags.FLAGS
# Dataset Options
flags.DEFINE_string('datasource', 'random',
'initialization for chains, either random or default (decorruption)')
flags.DEFINE_string('dataset','mnist',
'dsprites, cifar10, imagenet (32x32) or imagenetfull (128x128)')
flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
flags.DEFINE_bool('single', False, 'whether to debug by training on a single image')
flags.DEFINE_integer('data_workers', 4,
'Number of different data workers to load data in parallel')
# General Experiment Settings
flags.DEFINE_string('logdir', 'cachedir',
'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('log_interval', 10, 'log outputs every so many batches')
flags.DEFINE_integer('save_interval', 1000,'save outputs every so many batches')
flags.DEFINE_integer('test_interval', 1000,'evaluate outputs every so many batches')
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
flags.DEFINE_bool('train', True, 'whether to train or test')
flags.DEFINE_integer('epoch_num', 10000, 'Number of Epochs to train on')
flags.DEFINE_float('lr', 3e-4, 'Learning for training')
flags.DEFINE_integer('num_gpus', 1, 'number of gpus to train on')
# EBM Specific Experiments Settings
flags.DEFINE_float('ml_coeff', 1.0, 'Maximum Likelihood Coefficients')
flags.DEFINE_float('l2_coeff', 1.0, 'L2 Penalty training')
flags.DEFINE_bool('cclass', False, 'Whether to conditional training in models')
flags.DEFINE_bool('model_cclass', False,'use unsupervised clustering to infer fake labels')
flags.DEFINE_integer('temperature', 1, 'Temperature for energy function')
flags.DEFINE_string('objective', 'cd', 'use either contrastive divergence objective(least stable),'
'logsumexp(more stable)'
'softplus(most stable)')
flags.DEFINE_bool('zero_kl', False, 'whether to zero out the kl loss')
# Setting for MCMC sampling
flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images')
flags.DEFINE_string('proj_norm_type', 'li', 'Either li or l2 ball projection')
flags.DEFINE_integer('num_steps', 20, 'Steps of gradient descent for training')
flags.DEFINE_float('step_lr', 1.0, 'Size of steps for gradient descent')
flags.DEFINE_bool('replay_batch', False, 'Use MCMC chains initialized from a replay buffer.')
flags.DEFINE_bool('hmc', False, 'Whether to use HMC sampling to train models')
flags.DEFINE_float('noise_scale', 1.,'Relative amount of noise for MCMC')
flags.DEFINE_bool('pcd', False, 'whether to use pcd training instead')
# Architecture Settings
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('large_model', False, 'whether to use a large model')
flags.DEFINE_bool('larger_model', False, 'Deeper ResNet32 Network')
flags.DEFINE_bool('wider_model', False, 'Wider ResNet32 Network')
# Dataset settings
flags.DEFINE_bool('mixup', False, 'whether to add mixup to training images')
flags.DEFINE_bool('augment', False, 'whether to augmentations to images')
flags.DEFINE_float('rescale', 1.0, 'Factor to rescale inputs from 0-1 box')
# Dsprites specific experiments
flags.DEFINE_bool('cond_shape', False, 'condition of shape type')
flags.DEFINE_bool('cond_size', False, 'condition of shape size')
flags.DEFINE_bool('cond_pos', False, 'condition of position loc')
flags.DEFINE_bool('cond_rot', False, 'condition of rot')
FLAGS.step_lr = FLAGS.step_lr * FLAGS.rescale
FLAGS.batch_size *= FLAGS.num_gpus
print("{} batch size".format(FLAGS.batch_size))
def compress_x_mod(x_mod):
x_mod = (255 * np.clip(x_mod, 0, FLAGS.rescale) / FLAGS.rescale).astype(np.uint8)
return x_mod
def decompress_x_mod(x_mod):
x_mod = x_mod / 256 * FLAGS.rescale + \
np.random.uniform(0, 1 / 256 * FLAGS.rescale, x_mod.shape)
return x_mod
def make_image(tensor):
"""Convert an numpy representation image to Image protobuf"""
from PIL import Image
if len(tensor.shape) == 4:
_, height, width, channel = tensor.shape
elif len(tensor.shape) == 3:
height, width, channel = tensor.shape
elif len(tensor.shape) == 2:
height, width = tensor.shape
channel = 1
tensor = tensor.astype(np.uint8)
image = Image.fromarray(tensor)
import io
output = io.BytesIO()
image.save(output, format='PNG')
image_string = output.getvalue()
output.close()
return tf.Summary.Image(height=height,
width=width,
colorspace=channel,
encoded_image_string=image_string)
def log_image(im, logger, tag, step=0):
im = make_image(im)
summary = [tf.Summary.Value(tag=tag, image=im)]
summary = tf.Summary(value=summary)
event = event_pb2.Event(summary=summary)
event.step = step
logger.writer.WriteEvent(event)
logger.writer.Flush()
def rescale_im(image):
image = np.clip(image, 0, FLAGS.rescale)
if FLAGS.dataset == 'mnist' or FLAGS.dataset == 'dsprites':
return (np.clip((FLAGS.rescale - image) * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
else:
return (np.clip(image * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
X = target_vars['X']
Y = target_vars['Y']
X_NOISE = target_vars['X_NOISE']
train_op = target_vars['train_op']
energy_pos = target_vars['energy_pos']
energy_neg = target_vars['energy_neg']
loss_energy = target_vars['loss_energy']
loss_ml = target_vars['loss_ml']
loss_total = target_vars['total_loss']
gvs = target_vars['gvs']
x_grad = target_vars['x_grad']
x_grad_first = target_vars['x_grad_first']
x_off = target_vars['x_off']
temp = target_vars['temp']
x_mod = target_vars['x_mod']
LABEL = target_vars['LABEL']
LABEL_POS = target_vars['LABEL_POS']
weights = target_vars['weights']
test_x_mod = target_vars['test_x_mod']
eps = target_vars['eps_begin']
label_ent = target_vars['label_ent']
if FLAGS.use_attention:
gamma = weights[0]['atten']['gamma']
else:
gamma = tf.zeros(1)
val_output = [test_x_mod]
gvs_dict = dict(gvs)
log_output = [
train_op,
energy_pos,
energy_neg,
eps,
loss_energy,
loss_ml,
loss_total,
x_grad,
x_off,
x_mod,
gamma,
x_grad_first,
label_ent,
*gvs_dict.keys()]
output = [train_op, x_mod]
replay_buffer = ReplayBuffer(10000)
itr = resume_iter
x_mod = None
gd_steps = 1
dataloader_iterator = iter(dataloader)
best_inception = 0.0
for epoch in range(FLAGS.epoch_num):
for data_corrupt, data, label in dataloader:
data_corrupt = data_corrupt_init = data_corrupt.numpy()
data_corrupt_init = data_corrupt.copy()
data = data.numpy()
label = label.numpy()
label_init = label.copy()
if FLAGS.mixup:
idx = np.random.permutation(data.shape[0])
lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
data = data * lam + data[idx] * (1 - lam)
if FLAGS.replay_batch and (x_mod is not None):
replay_buffer.add(compress_x_mod(x_mod))
if len(replay_buffer) > FLAGS.batch_size:
replay_batch = replay_buffer.sample(FLAGS.batch_size)
replay_batch = decompress_x_mod(replay_batch)
replay_mask = (
np.random.uniform(
0,
FLAGS.rescale,
FLAGS.batch_size) > 0.05)
data_corrupt[replay_mask] = replay_batch[replay_mask]
if FLAGS.pcd:
if x_mod is not None:
data_corrupt = x_mod
feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}
if FLAGS.cclass:
feed_dict[LABEL] = label
feed_dict[LABEL_POS] = label_init
if itr % FLAGS.log_interval == 0:
_, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent, * \
grads = sess.run(log_output, feed_dict)
kvs = {}
kvs['e_pos'] = e_pos.mean()
kvs['e_pos_std'] = e_pos.std()
kvs['e_neg'] = e_neg.mean()
kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
kvs['e_neg_std'] = e_neg.std()
kvs['temp'] = temp
kvs['loss_e'] = loss_e.mean()
kvs['eps'] = eps.mean()
kvs['label_ent'] = label_ent
kvs['loss_ml'] = loss_ml.mean()
kvs['loss_total'] = loss_total.mean()
kvs['x_grad'] = np.abs(x_grad).mean()
kvs['x_grad_first'] = np.abs(x_grad_first).mean()
kvs['x_off'] = x_off.mean()
kvs['iter'] = itr
kvs['gamma'] = gamma
for v, k in zip(grads, [v.name for v in gvs_dict.values()]):
kvs[k] = np.abs(v).max()
string = "Obtained a total of "
for key, value in kvs.items():
string += "{}: {}, ".format(key, value)
if hvd.rank() == 0:
print(string)
logger.writekvs(kvs)
else:
_, x_mod = sess.run(output, feed_dict)
if itr % FLAGS.save_interval == 0 and hvd.rank() == 0:
saver.save(
sess,
osp.join(
FLAGS.logdir,
FLAGS.exp,
'model_{}'.format(itr)))
if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != '2d':
try_im = x_mod
orig_im = data_corrupt.squeeze()
actual_im = rescale_im(data)
orig_im = rescale_im(orig_im)
try_im = rescale_im(try_im).squeeze()
for i, (im, t_im, actual_im_i) in enumerate(
zip(orig_im[:20], try_im[:20], actual_im)):
shape = orig_im.shape[1:]
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
size = shape[1]
new_im[:, :size] = im
new_im[:, size:2 * size] = t_im
new_im[:, 2 * size:] = actual_im_i
log_image(
new_im, logger, 'train_gen_{}'.format(itr), step=i)
test_im = x_mod
try:
data_corrupt, data, label = next(dataloader_iterator)
except BaseException:
dataloader_iterator = iter(dataloader)
data_corrupt, data, label = next(dataloader_iterator)
data_corrupt = data_corrupt.numpy()
if FLAGS.replay_batch and (
x_mod is not None) and len(replay_buffer) > 0:
replay_batch = replay_buffer.sample(FLAGS.batch_size)
replay_batch = decompress_x_mod(replay_batch)
replay_mask = (
np.random.uniform(
0, 1, (FLAGS.batch_size)) > 0.05)
data_corrupt[replay_mask] = replay_batch[replay_mask]
if FLAGS.dataset == 'cifar10' or FLAGS.dataset == 'imagenet' or FLAGS.dataset == 'imagenetfull':
n = 128
if FLAGS.dataset == "imagenetfull":
n = 32
if len(replay_buffer) > n:
data_corrupt = decompress_x_mod(replay_buffer.sample(n))
elif FLAGS.dataset == 'imagenetfull':
data_corrupt = np.random.uniform(
0, FLAGS.rescale, (n, 128, 128, 3))
else:
data_corrupt = np.random.uniform(
0, FLAGS.rescale, (n, 32, 32, 3))
if FLAGS.dataset == 'cifar10':
label = np.eye(10)[np.random.randint(0, 10, (n))]
else:
label = np.eye(1000)[
np.random.randint(
0, 1000, (n))]
feed_dict[X_NOISE] = data_corrupt
feed_dict[X] = data
if FLAGS.cclass:
feed_dict[LABEL] = label
test_x_mod = sess.run(val_output, feed_dict)
try_im = test_x_mod
orig_im = data_corrupt.squeeze()
actual_im = rescale_im(data.numpy())
orig_im = rescale_im(orig_im)
try_im = rescale_im(try_im).squeeze()
for i, (im, t_im, actual_im_i) in enumerate(
zip(orig_im[:20], try_im[:20], actual_im)):
shape = orig_im.shape[1:]
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
size = shape[1]
new_im[:, :size] = im
new_im[:, size:2 * size] = t_im
new_im[:, 2 * size:] = actual_im_i
log_image(
new_im, logger, 'val_gen_{}'.format(itr), step=i)
score, std = get_inception_score(list(try_im), splits=1)
print(
"Inception score of {} with std of {}".format(
score, std))
kvs = {}
kvs['inception_score'] = score
kvs['inception_score_std'] = std
logger.writekvs(kvs)
if score > best_inception:
best_inception = score
saver.save(
sess,
osp.join(
FLAGS.logdir,
FLAGS.exp,
'model_best'))
if itr > 60000 and FLAGS.dataset == "mnist":
assert False
itr += 1
saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))
cifar10_map = {0: 'airplane',
1: 'automobile',
2: 'bird',
3: 'cat',
4: 'deer',
5: 'dog',
6: 'frog',
7: 'horse',
8: 'ship',
9: 'truck'}
def test(target_vars, saver, sess, logger, dataloader):
X_NOISE = target_vars['X_NOISE']
X = target_vars['X']
Y = target_vars['Y']
LABEL = target_vars['LABEL']
energy_start = target_vars['energy_start']
x_mod = target_vars['x_mod']
x_mod = target_vars['test_x_mod']
energy_neg = target_vars['energy_neg']
np.random.seed(1)
random.seed(1)
output = [x_mod, energy_start, energy_neg]
dataloader_iterator = iter(dataloader)
data_corrupt, data, label = next(dataloader_iterator)
data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()
orig_im = try_im = data_corrupt
if FLAGS.cclass:
try_im, energy_orig, energy = sess.run(
output, {X_NOISE: orig_im, Y: label[0:1], LABEL: label})
else:
try_im, energy_orig, energy = sess.run(
output, {X_NOISE: orig_im, Y: label[0:1]})
orig_im = rescale_im(orig_im)
try_im = rescale_im(try_im)
actual_im = rescale_im(data)
for i, (im, energy_i, t_im, energy, label_i, actual_im_i) in enumerate(
zip(orig_im, energy_orig, try_im, energy, label, actual_im)):
label_i = np.array(label_i)
shape = im.shape[1:]
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
size = shape[1]
new_im[:, :size] = im
new_im[:, size:2 * size] = t_im
if FLAGS.cclass:
label_i = np.where(label_i == 1)[0][0]
if FLAGS.dataset == 'cifar10':
log_image(new_im, logger, '{}_{:.4f}_now_{:.4f}_{}'.format(
i, energy_i[0], energy[0], cifar10_map[label_i]), step=i)
else:
log_image(
new_im,
logger,
'{}_{:.4f}_now_{:.4f}_{}'.format(
i,
energy_i[0],
energy[0],
label_i),
step=i)
else:
log_image(
new_im,
logger,
'{}_{:.4f}_now_{:.4f}'.format(
i,
energy_i[0],
energy[0]),
step=i)
test_ims = list(try_im)
real_ims = list(actual_im)
for i in tqdm(range(50000 // FLAGS.batch_size + 1)):
try:
data_corrupt, data, label = dataloader_iterator.next()
except BaseException:
dataloader_iterator = iter(dataloader)
data_corrupt, data, label = dataloader_iterator.next()
data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()
if FLAGS.cclass:
try_im, energy_orig, energy = sess.run(
output, {X_NOISE: data_corrupt, Y: label[0:1], LABEL: label})
else:
try_im, energy_orig, energy = sess.run(
output, {X_NOISE: data_corrupt, Y: label[0:1]})
try_im = rescale_im(try_im)
real_im = rescale_im(data)
test_ims.extend(list(try_im))
real_ims.extend(list(real_im))
score, std = get_inception_score(test_ims)
print("Inception score of {} with std of {}".format(score, std))
def main():
print("Local rank: ", hvd.local_rank(), hvd.size())
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
if hvd.rank() == 0:
if not osp.exists(logdir):
os.makedirs(logdir)
logger = TensorBoardOutputFormat(logdir)
else:
logger = None
LABEL = None
print("Loading data...")
if FLAGS.dataset == 'cifar10':
dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
channel_num = 3
X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)
if FLAGS.large_model:
model = ResNet32Large(
num_channels=channel_num,
num_filters=128,
train=True)
elif FLAGS.larger_model:
model = ResNet32Larger(
num_channels=channel_num,
num_filters=128)
elif FLAGS.wider_model:
model = ResNet32Wider(
num_channels=channel_num,
num_filters=192)
else:
model = ResNet32(
num_channels=channel_num,
num_filters=128)
elif FLAGS.dataset == 'imagenet':
dataset = Imagenet(train=True)
test_dataset = Imagenet(train=False)
channel_num = 3
X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
model = ResNet32Wider(
num_channels=channel_num,
num_filters=256)
elif FLAGS.dataset == 'imagenetfull':
channel_num = 3
X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
model = ResNet128(
num_channels=channel_num,
num_filters=64)
elif FLAGS.dataset == 'mnist':
dataset = Mnist(rescale=FLAGS.rescale)
test_dataset = dataset
channel_num = 1
X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)
model = MnistNet(
num_channels=channel_num,
num_filters=FLAGS.num_filters)
elif FLAGS.dataset == 'dsprites':
dataset = DSprites(
cond_shape=FLAGS.cond_shape,
cond_size=FLAGS.cond_size,
cond_pos=FLAGS.cond_pos,
cond_rot=FLAGS.cond_rot)
test_dataset = dataset
channel_num = 1
X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
if FLAGS.dpos_only:
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
elif FLAGS.dsize_only:
LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
elif FLAGS.drot_only:
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
elif FLAGS.cond_size:
LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
elif FLAGS.cond_shape:
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
elif FLAGS.cond_pos:
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
elif FLAGS.cond_rot:
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
else:
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
model = DspritesNet(
num_channels=channel_num,
num_filters=FLAGS.num_filters,
cond_size=FLAGS.cond_size,
cond_shape=FLAGS.cond_shape,
cond_pos=FLAGS.cond_pos,
cond_rot=FLAGS.cond_rot)
print("Done loading...")
if FLAGS.dataset == "imagenetfull":
# In the case of full imagenet, use custom_tensorflow dataloader
data_loader = TFImagenetLoader('train', FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale)
else:
data_loader = DataLoader(
dataset,
batch_size=FLAGS.batch_size,
num_workers=FLAGS.data_workers,
drop_last=True,
shuffle=True)
batch_size = FLAGS.batch_size
weights = [model.construct_weights('context_0')]
Y = tf.placeholder(shape=(None), dtype=tf.int32)
# Varibles to run in training
X_SPLIT = tf.split(X, FLAGS.num_gpus)
X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
LABEL_SPLIT_INIT = list(LABEL_SPLIT)
tower_grads = []
tower_gen_grads = []
x_mod_list = []
optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
optimizer = hvd.DistributedOptimizer(optimizer)
for j in range(FLAGS.num_gpus):
if FLAGS.model_cclass:
ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
label_tensor = tf.Variable(
tf.convert_to_tensor(
np.reshape(
np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
(FLAGS.batch_size * 10, 10)),
dtype=tf.float32),
trainable=False,
dtype=tf.float32)
x_split = tf.tile(
tf.reshape(
X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1))
x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
energy_pos = model.forward(
x_split,
weights[0],
label=label_tensor,
stop_at_grad=False)
energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
energy_partition_est = tf.reduce_logsumexp(
energy_pos_full, axis=1, keepdims=True)
uniform = tf.random_uniform(tf.shape(energy_pos_full))
label_tensor = tf.argmax(-energy_pos_full -
tf.log(-tf.log(uniform)) - energy_partition_est, axis=1)
label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
label = tf.Print(label, [label_tensor, energy_pos_full])
LABEL_SPLIT[j] = label
energy_pos = tf.concat(energy_pos, axis=0)
else:
energy_pos = [
model.forward(
X_SPLIT[j],
weights[0],
label=LABEL_POS_SPLIT[j],
stop_at_grad=False)]
energy_pos = tf.concat(energy_pos, axis=0)
print("Building graph...")
x_mod = x_orig = X_NOISE_SPLIT[j]
x_grads = []
energy_negs = []
loss_energys = []
energy_negs.extend([model.forward(tf.stop_gradient(
x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
eps_begin = tf.zeros(1)
steps = tf.constant(0)
c = lambda i, x: tf.less(i, FLAGS.num_steps)
def langevin_step(counter, x_mod):
x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
mean=0.0,
stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)
energy_noise = energy_start = tf.concat(
[model.forward(
x_mod,
weights[0],
label=LABEL_SPLIT[j],
reuse=True,
stop_at_grad=False,
stop_batch=True)],
axis=0)
x_grad, label_grad = tf.gradients(
FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]])
energy_noise_old = energy_noise
lr = FLAGS.step_lr
if FLAGS.proj_norm != 0.0:
if FLAGS.proj_norm_type == 'l2':
x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
elif FLAGS.proj_norm_type == 'li':
x_grad = tf.clip_by_value(
x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
else:
print("Other types of projection are not supported!!!")
assert False
# Clip gradient norm for now
if FLAGS.hmc:
# Step size should be tuned to get around 65% acceptance
def energy(x):
return FLAGS.temperature * \
model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)
x_last = hmc(x_mod, 15., 10, energy)
else:
x_last = x_mod - (lr) * x_grad
x_mod = x_last
x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)
counter = counter + 1
return counter, x_mod
steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))
energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
stop_at_grad=False, reuse=True)
x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
x_grads.append(x_grad)
energy_negs.append(
model.forward(
tf.stop_gradient(x_mod),
weights[0],
label=LABEL_SPLIT[j],
stop_at_grad=False,
reuse=True))
test_x_mod = x_mod
temp = FLAGS.temperature
energy_neg = energy_negs[-1]
x_off = tf.reduce_mean(
tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))
loss_energy = model.forward(
x_mod,
weights[0],
reuse=True,
label=LABEL,
stop_grad=True)
print("Finished processing loop construction ...")
target_vars = {}
if FLAGS.cclass or FLAGS.model_cclass:
label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
label_prob = label_sum / tf.reduce_sum(label_sum)
label_ent = -tf.reduce_sum(label_prob *
tf.math.log(label_prob + 1e-7))
else:
label_ent = tf.zeros(1)
target_vars['label_ent'] = label_ent
if FLAGS.train:
if FLAGS.objective == 'logsumexp':
pos_term = temp * energy_pos
energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
pos_loss = tf.reduce_mean(temp * energy_pos)
neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
elif FLAGS.objective == 'cd':
pos_loss = tf.reduce_mean(temp * energy_pos)
neg_loss = -tf.reduce_mean(temp * energy_neg)
loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
elif FLAGS.objective == 'softplus':
loss_ml = FLAGS.ml_coeff * \
tf.nn.softplus(temp * (energy_pos - energy_neg))
loss_total = tf.reduce_mean(loss_ml)
if not FLAGS.zero_kl:
loss_total = loss_total + tf.reduce_mean(loss_energy)
loss_total = loss_total + \
FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))
print("Started gradient computation...")
gvs = optimizer.compute_gradients(loss_total)
gvs = [(k, v) for (k, v) in gvs if k is not None]
print("Applying gradients...")
tower_grads.append(gvs)
print("Finished applying gradients.")
target_vars['loss_ml'] = loss_ml
target_vars['total_loss'] = loss_total
target_vars['loss_energy'] = loss_energy
target_vars['weights'] = weights
target_vars['gvs'] = gvs
target_vars['X'] = X
target_vars['Y'] = Y
target_vars['LABEL'] = LABEL
target_vars['LABEL_POS'] = LABEL_POS
target_vars['X_NOISE'] = X_NOISE
target_vars['energy_pos'] = energy_pos
target_vars['energy_start'] = energy_negs[0]
if len(x_grads) >= 1:
target_vars['x_grad'] = x_grads[-1]
target_vars['x_grad_first'] = x_grads[0]
else:
target_vars['x_grad'] = tf.zeros(1)
target_vars['x_grad_first'] = tf.zeros(1)
target_vars['x_mod'] = x_mod
target_vars['x_off'] = x_off
target_vars['temp'] = temp
target_vars['energy_neg'] = energy_neg
target_vars['test_x_mod'] = test_x_mod
target_vars['eps_begin'] = eps_begin
if FLAGS.train:
grads = average_gradients(tower_grads)
train_op = optimizer.apply_gradients(grads)
target_vars['train_op'] = train_op
config = tf.ConfigProto()
if hvd.size() > 1:
config.gpu_options.visible_device_list = str(hvd.local_rank())
sess = tf.Session(config=config)
saver = loader = tf.train.Saver(
max_to_keep=30, keep_checkpoint_every_n_hours=6)
total_parameters = 0
for variable in tf.trainable_variables():
# shape is an array of tf.Dimension
shape = variable.get_shape()
variable_parameters = 1
for dim in shape:
variable_parameters *= dim.value
total_parameters += variable_parameters
print("Model has a total of {} parameters".format(total_parameters))
sess.run(tf.global_variables_initializer())
resume_itr = 0
if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
resume_itr = FLAGS.resume_iter
# saver.restore(sess, model_file)
optimistic_restore(sess, model_file)
sess.run(hvd.broadcast_global_variables(0))
print("Initializing variables...")
print("Start broadcast")
print("End broadcast")
if FLAGS.train:
train(target_vars, saver, sess,
logger, data_loader, resume_itr,
logdir)
test(target_vars, saver, sess, logger, data_loader)
if __name__ == "__main__":
main()

1107
utils.py Normal file

File diff suppressed because it is too large Load Diff