mirror of https://github.com/autistic-symposium/ml-ai-agents-py.git
synced 2025-05-07 17:14:56 -04:00

Clean up to start add modern models (#24)

parent 94d09f6fba
commit 3f8821f1d4

34 changed files with 845 additions and 309 deletions
253 EBMs/README.md (Normal file)

@ -0,0 +1,253 @@
## quantum ai: training energy-based-models using openai

<br>

#### ⚛️ this repository contains my adapted code from [openai's implicit generation and generalization in energy-based-models](https://arxiv.org/pdf/1903.08689.pdf)

<br>
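* for orientation, a minimal numpy sketch of the idea in the paper is shown below: negative samples are drawn by running langevin dynamics on the energy function, and the loss pushes energy down on data and up on those samples; the toy quadratic energy and all constants here are made up for illustration, this is not the repository's implementation:

```python
# toy sketch of ebm training with langevin-dynamics negatives (illustrative only)
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(2, 2))        # toy quadratic energy E(x) = ||x W||^2


def energy(x):
    return np.sum((x @ W) ** 2, axis=-1)


def energy_grad_x(x):
    return 2 * (x @ W) @ W.T


def langevin_sample(x, steps=60, step_lr=10.0, noise=0.005):
    # loosely analogous to --num_steps and --step_lr in the train.py commands below
    for _ in range(steps):
        x = x - step_lr * energy_grad_x(x) + noise * rng.normal(size=x.shape)
        x = np.clip(x, -1, 1)
    return x


x_pos = rng.normal(scale=0.1, size=(128, 2))                  # stand-in for real data
x_neg = langevin_sample(rng.uniform(-1, 1, size=(128, 2)))    # model samples
loss = energy(x_pos).mean() - energy(x_neg).mean()            # maximum-likelihood surrogate
print(f"e_pos={energy(x_pos).mean():.4f} e_neg={energy(x_neg).mean():.4f} loss={loss:.4f}")
# the energy parameters (here W) would then be updated to decrease this loss
```

<br>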
### installing

<br>

```bash
brew install gcc@6
brew install open-mpi
brew install pkg-config
```

<br>

* there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi for the specific libraries used in this project (`PMIX ERROR: ERROR`), which can be fixed with:

<br>

```bash
export PMIX_MCA_gds=^ds12
```

<br>

* then install python's requirements:

<br>

```bash
virtualenv venv
source venv/bin/activate
pip install -r requirements.txt
```

<br>

* note that this is an adapted requirements file, since **[openai's original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is neither complete nor correct
* finally, download and install **[mujoco](https://www.roboti.us/index.html)**
* you will also need to register for a license, which asks for a machine ID
* the documentation on the website is incomplete, so just download the suggested script and run:

<br>

```bash
mv getid_osx getid_osx.dms
./getid_osx.dms
```

<br>

---

### running

<br>

#### download pre-trained models (examples)

<br>

* download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder `cachedir`:

<br>

```bash
mkdir cachedir
```

<br>

#### setting the results directory

<br>

* openai's original code contains **[hardcoded constants that only work on linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
* i changed this to a constant (`ROOT_DIR = "./results"`) at the top of `data.py` (a sketch of the change is shown below)

<br>
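* a minimal sketch of that change (the dataset call mirrors `EBMs/data.py` in this commit; the transform here is only illustrative):

```python
# all downloads and results now live under one relative folder instead of a
# hardcoded absolute linux path
from torchvision import transforms
from torchvision.datasets import CIFAR10

ROOT_DIR = "./results"

train_data = CIFAR10(ROOT_DIR, transform=transforms.ToTensor(), train=True, download=True)
```

<br>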
#### running (parallelization with `mpiexec`)

<br>

* all the code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so training can be sped up substantially by running any of the commands below across multiple workers:

<br>

```bash
mpiexec -n <worker_num> <command>
```

<br>
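* for example, assuming a machine with 4 available workers (the count here is just an illustration), the unconditional cifar-10 training below would be launched as `mpiexec -n 4 python train.py --exp=cifar10_uncond ...`

<br>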
##### cifar-10 unconditional

<br>

```bash
python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --large_model
```

* this should generate output along the following lines:

<br>

```bash
Instructions for updating:
Use tf.gfile.GFile.
2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
64 batch size
Local rank: 0 1
Loading data...
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Done loading...
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
Building graph...
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
Finished processing loop construction ...
Started gradient computation...
Applying gradients...
Finished applying gradients.
Model has a total of 7567880 parameters
Initializing variables...
Start broadcast
End broadcast
Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff: 0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent: 2.272536277770996, loss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first: 0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818, context_0/res_optim_res_c1/cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10, context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10, context_0/res_optim_res_c2/gb:0: 6.272356523062683e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11, context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437, context_0/res_1_res_c2/g:0: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0: 1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0: 6.868505764145993e-10, context_0/res_2_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0: 8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0: 0.01942753791809082, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10, context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11, context_0/res_3_res_c2/gb:0: 7.756124631441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0: 4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0: 0.34806951880455017, context_0/res_4_res_c2/g:0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0: 5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0: 3.807749116013781e-10, context_0/res_5_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0: 7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039, context_0/fc5/b:0: 0.4506262540817261,

................................................................................................................................
Inception score of 1.2397289276123047 with std of 0.0
```

<br>
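* in the log above, `e_pos` is the mean energy assigned to real cifar-10 images, `e_neg` the energy of the langevin-sampled negatives, and `loss_ml` (roughly `e_pos - e_neg`) is the maximum-likelihood surrogate being minimized; the inception score reported at iteration 0 is naturally low and should increase as training progresses

<br>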
##### cifar-10 conditional

<br>

```bash
python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --cclass
```

<br>

##### imagenet 32x32 conditional

<br>

```bash
python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 --step_lr=10.0 --proj_norm=0.01 --replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=<imagenet32x32 path>
```

<br>

##### imagenet 128x128 conditional

<br>

```bash
python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 --step_lr=100.0 --replay_batch --swish_act --cclass --zero_kl --dataset=imagenetfull --imagenet_datadir=<full imagenet path>
```

<br>
##### imagenet demo

<br>

* the `imagenet_demo.py` file contains code for experiments with ebms on conditional imagenet 128x128
* to generate a gif of the sampling process, you can run:

<br>

```bash
python imagenet_demo.py --exp=imagenet128_cond --resume_iter=2238000 --swish_act
```

* the `ebm_sandbox.py` file contains several different tasks that can be used to evaluate ebms, selected through the `--task` flag
* for example, to visualize cross-class mappings in cifar-10, you can run:

<br>

```bash
python ebm_sandbox.py --task=crossclass --num_steps=40 --exp=cifar10_cond --resume_iter=74700
```

<br>
##### generalization

<br>

* to test generalization to out-of-distribution classification on svhn (with similar commands for other datasets):

<br>

```bash
python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200 --large_model --svhnmix --cclass=False
```

<br>

* to test classification on cifar-10 using a conditional model under either l2 or li perturbations:

<br>

```bash
python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd=<number of pgd steps> --num_steps=10 --lival=<li bound value> --wider_model
```

<br>
##### concept combination

<br>

* to train ebms on the conditional dsprites dataset, you can train a separate model on each conditioned latent in `cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`, with an example command given below:

<br>

```bash
python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act --cond_pos --replay_batch --cclass
```

<br>

* once the models are trained, they can be sampled from jointly by running:

```bash
python ebm_combine.py --task=conceptcombine --exp_size=<exp_size> --exp_shape=<exp_shape> --exp_pos=<exp_pos> --exp_rot=<exp_rot> --resume_size=<resume_size> --resume_shape=<resume_shape> --resume_rot=<resume_rot> --resume_pos=<resume_pos>
```
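* conceptually, joint sampling works because energies compose additively: the combined model runs langevin dynamics on the sum of the individual conditional energy functions (this is the concept-combination idea from the paper, summarized here only for context)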
249 EBMs/ais.py (Normal file)

@ -0,0 +1,249 @@
import tensorflow as tf
import math
from hmc import hmc
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
from data import Cifar10, Mnist, DSprites
from scipy.misc import logsumexp
from scipy.misc import imsave
from utils import optimistic_restore
import os.path as osp
import numpy as np
from tqdm import tqdm

flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss')
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
flags.DEFINE_integer('batch_size', 16, 'Size of inputs')
flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from')

flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais')
flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian')
flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box')
flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model')
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not')
flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not')
flags.DEFINE_bool('large_model', False, 'Use large model to evaluate')
flags.DEFINE_bool('wider_model', False, 'Use large model to evaluate')
flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps')

FLAGS = flags.FLAGS

label_default = np.eye(10)[0:1, :]
label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))


def unscale_im(im):
    return (255 * np.clip(im, 0, 1)).astype(np.uint8)


def gauss_prob_log(x, prec=1.0):

    nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
    norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
    prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec

    return norm_constant_log + prob_density_log


def uniform_prob_log(x):

    return tf.zeros(1)


def model_prob_log(x, e_func, weights, temp):
    if FLAGS.cclass:
        batch_size = tf.shape(x)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_raw = e_func.forward(x, weights, label=label_tiled)
    else:
        e_raw = e_func.forward(x, weights)
    energy = tf.reduce_sum(e_raw, axis=[1])
    return -temp * energy
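# note (added in this mirror for readability): the functions below build the
# annealed-importance-sampling bridge. bridge_prob_neg_log interpolates, with
# weight alpha, between a base distribution (uniform, or a gaussian for the
# 'gauss' dataset) and the EBM's unnormalized density; main() then accumulates
# the log importance weights while alpha is swept from 0 to 1 (and back) in
# FLAGS.pdist steps, giving stochastic lower/upper bounds on the log partition
# function.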
def bridge_prob_neg_log(alpha, x, e_func, weights, temp):

    if FLAGS.dataset == "gauss":
        norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature)
    else:
        norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp)

    # Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely
    if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1])
    elif FLAGS.dataset == 'mnist':
        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1, 2])
    else:
        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis=[1, 2, 3])

    return -norm_prob + oob_prob


def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10):
    if FLAGS.dataset == "2d":
        x = tf.placeholder(tf.float32, shape=(None, 2))
    elif FLAGS.dataset == "gauss":
        x = tf.placeholder(tf.float32, shape=(None, FLAGS.gauss_dim))
    elif FLAGS.dataset == "mnist":
        x = tf.placeholder(tf.float32, shape=(None, 28, 28))
    else:
        x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3))

    x_init = x

    alpha_prev = tf.placeholder(tf.float32, shape=())
    alpha_new = tf.placeholder(tf.float32, shape=())
    approx_lr = tf.placeholder(tf.float32, shape=())

    chain_weights = tf.zeros(batch_size)
    # for i in range(1, prop_dist+1):
    # print("processing loop {}".format(i))
    # alpha_prev = (i-1) / prop_dist
    # alpha_new = i / prop_dist

    prob_log_old_neg = bridge_prob_neg_log(alpha_prev, x, e_func, weights, temp)
    prob_log_new_neg = bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)

    chain_weights = -prob_log_new_neg + prob_log_old_neg
    # chain_weights = tf.Print(chain_weights, [chain_weights])

    # Sample new x using HMC
    def unorm_prob(x):
        return bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)

    for j in range(1):
        x = hmc(x, approx_lr, hmc_step, unorm_prob)

    return chain_weights, alpha_prev, alpha_new, x, x_init, approx_lr


def main():

    # Initialize dataset
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3
        dim_input = 32 * 32 * 3
    elif FLAGS.dataset == 'imagenet':
        dataset = ImagenetClass()
        channel_num = 3
        dim_input = 64 * 64 * 3
    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(train=False, rescale=FLAGS.rescale)
        channel_num = 1
        dim_input = 28 * 28 * 1
    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites()
        channel_num = 1
        dim_input = 64 * 64 * 1
    elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
        dataset = Box2D()

    dim_output = 1
    data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)

    if FLAGS.dataset == 'mnist':
        model = MnistNet(num_channels=channel_num)
    elif FLAGS.dataset == 'cifar10':
        if FLAGS.large_model:
            model = ResNet32Large(num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)
    elif FLAGS.dataset == 'dsprites':
        model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)

    weights = model.construct_weights('context_{}'.format(0))

    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    saver = loader = tf.train.Saver(max_to_keep=10)

    sess.run(tf.global_variables_initializer())
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
    resume_itr = FLAGS.resume_iter

    if FLAGS.resume_iter != "-1":
        optimistic_restore(sess, model_file)
    else:
        print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
        # saver.restore(sess, model_file)

    chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
    print("Finished constructing ancestral sample ...................")

    if FLAGS.dataset != "gauss":
        comb_weights_cum = []
        batch_size = tf.shape(x_init)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
        e_pos_list = []

        for data_corrupt, data, label_gt in tqdm(data_loader):
            e_pos = sess.run([e_compute], {x_init: data})[0]
            e_pos_list.extend(list(e_pos))

        print(len(e_pos_list))
        print("Positive sample probability ", np.mean(e_pos_list), np.std(e_pos_list))

    if FLAGS.dataset == "2d":
        alr = 0.0045
    elif FLAGS.dataset == "gauss":
        alr = 0.0085
    elif FLAGS.dataset == "mnist":
        alr = 0.0065
        # alr = 0.0035
    else:
        # alr = 0.0125
        if FLAGS.rescale == 8:
            alr = 0.0085
        else:
            alr = 0.0045

    for i in range(1):
        tot_weight = 0
        for j in tqdm(range(1, FLAGS.pdist+1)):
            if j == 1:
                if FLAGS.dataset == "cifar10":
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
                elif FLAGS.dataset == "gauss":
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
                elif FLAGS.dataset == "mnist":
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
                else:
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))

            alpha_prev = (j-1) / FLAGS.pdist
            alpha_new = j / FLAGS.pdist
            cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
            tot_weight = tot_weight + cweight

        print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))

        tot_weight = 0

        for j in tqdm(range(FLAGS.pdist, 0, -1)):
            alpha_new = (j-1) / FLAGS.pdist
            alpha_prev = j / FLAGS.pdist
            cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
            tot_weight = tot_weight - cweight

        print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))


if __name__ == "__main__":
    main()
236 EBMs/custom_adam.py (Normal file)

@ -0,0 +1,236 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Adam for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export
import tensorflow as tf


@tf_export("train.AdamOptimizer")
class AdamOptimizer(optimizer.Optimizer):
  """Optimizer that implements the Adam algorithm.

  See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
  ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
  """

  def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
               use_locking=False, name="Adam"):
    """Construct a new Adam optimizer.

    Initialization:

    $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
    $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
    $$t := 0 \text{(Initialize timestep)}$$

    The update rule for `variable` with gradient `g` uses an optimization
    described at the end of section 2 of the paper:

    $$t := t + 1$$
    $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$

    $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
    $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
    $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$

    The default value of 1e-8 for epsilon might not be a good default in
    general. For example, when training an Inception network on ImageNet a
    current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
    formulation just before Section 2.1 of the Kingma and Ba paper rather than
    the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
    hat" in the paper.

    The sparse implementation of this algorithm (used when the gradient is an
    IndexedSlices object, typically because of `tf.gather` or an embedding
    lookup in the forward pass) does apply momentum to variable slices even if
    they were not used in the forward pass (meaning they have a gradient equal
    to zero). Momentum decay (beta1) is also applied to the entire momentum
    accumulator. This means that the sparse behavior is equivalent to the dense
    behavior (in contrast to some momentum implementations which ignore momentum
    unless a variable slice was actually used).

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate.
      beta1: A float value or a constant float tensor.
        The exponential decay rate for the 1st moment estimates.
      beta2: A float value or a constant float tensor.
        The exponential decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      use_locking: If True use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "Adam".

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
    """
    super(AdamOptimizer, self).__init__(use_locking, name)
    self._lr = learning_rate
    self._beta1 = beta1
    self._beta2 = beta2
    self._epsilon = epsilon

    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._beta1_t = None
    self._beta2_t = None
    self._epsilon_t = None

    # Created in SparseApply if needed.
    self._updated_lr = None

  def _get_beta_accumulators(self):
    with ops.init_scope():
      if context.executing_eagerly():
        graph = None
      else:
        graph = ops.get_default_graph()
      return (self._get_non_slot_variable("beta1_power", graph=graph),
              self._get_non_slot_variable("beta2_power", graph=graph))

  def _create_slots(self, var_list):
    # Create the beta1 and beta2 accumulators on the same device as the first
    # variable. Sort the var_list to make sure this device is consistent across
    # workers (these need to go on the same PS, otherwise some updates are
    # silently ignored).
    first_var = min(var_list, key=lambda x: x.name)
    self._create_non_slot_variable(initial_value=self._beta1,
                                   name="beta1_power",
                                   colocate_with=first_var)
    self._create_non_slot_variable(initial_value=self._beta2,
                                   name="beta2_power",
                                   colocate_with=first_var)

    # Create slots for the first and second moments.
    for v in var_list:
      self._zeros_slot(v, "m", self._name)
      self._zeros_slot(v, "v", self._name)

  def _prepare(self):
    lr = self._call_if_callable(self._lr)
    beta1 = self._call_if_callable(self._beta1)
    beta2 = self._call_if_callable(self._beta2)
    epsilon = self._call_if_callable(self._epsilon)

    self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
    self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
    self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
    self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
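  # note (added in this mirror for readability): _apply_dense below is where
  # this file deviates from the stock tf.train.AdamOptimizer. it clips each
  # incoming gradient to roughly 3 standard deviations of the bias-corrected
  # second-moment estimate (3 * sqrt(v / (1 - beta2^t)) + 0.1) before handing
  # it to the fused Adam kernel; _resource_apply_dense and the sparse paths
  # are unchanged from upstream TensorFlow.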
  def _apply_dense(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()

    # Clip gradients by 3 std
    clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
    grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
    return training_ops.apply_adam(
        var, m, v,
        math_ops.cast(beta1_power, var.dtype.base_dtype),
        math_ops.cast(beta2_power, var.dtype.base_dtype),
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._beta1_t, var.dtype.base_dtype),
        math_ops.cast(self._beta2_t, var.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
        grad, use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
    return training_ops.resource_apply_adam(
        var.handle, m.handle, v.handle,
        math_ops.cast(beta1_power, grad.dtype.base_dtype),
        math_ops.cast(beta2_power, grad.dtype.base_dtype),
        math_ops.cast(self._lr_t, grad.dtype.base_dtype),
        math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
        math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
        grad, use_locking=self._use_locking)

  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * (1 - beta1_t)
    m_t = state_ops.assign(m, m * beta1_t,
                           use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
      m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = scatter_add(v, indices, v_scaled_g_values)
    v_sqrt = math_ops.sqrt(v_t)
    var_update = state_ops.assign_sub(var,
                                      lr * m_t / (v_sqrt + epsilon_t),
                                      use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t, v_t])

  def _apply_sparse(self, grad, var):
    return self._apply_sparse_shared(
        grad.values, var, grad.indices,
        lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
            x, i, v, use_locking=self._use_locking))

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(
            x.handle, i, v)]):
      return x.value()

  def _resource_apply_sparse(self, grad, var, indices):
    return self._apply_sparse_shared(
        grad, var, indices, self._resource_scatter_add)

  def _finish(self, update_ops, name_scope):
    # Update the power accumulators.
    with ops.control_dependencies(update_ops):
      beta1_power, beta2_power = self._get_beta_accumulators()
      with ops.colocate_with(beta1_power):
        update_beta1 = beta1_power.assign(
            beta1_power * self._beta1_t, use_locking=self._use_locking)
        update_beta2 = beta2_power.assign(
            beta2_power * self._beta2_t, use_locking=self._use_locking)
    return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
                                  name=name_scope)
546 EBMs/data.py (Normal file)

@ -0,0 +1,546 @@
from tensorflow.python.platform import flags
from tensorflow.contrib.data.python.ops import batching, threadpool
import tensorflow as tf
import json
from torch.utils.data import Dataset
import pickle
import os.path as osp
import os
import numpy as np
import time
from scipy.misc import imread, imresize
from skimage.color import rgb2grey
from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
from torchvision import transforms
from imagenet_preprocessing import ImagenetPreprocessor
import torch
import torchvision

FLAGS = flags.FLAGS
ROOT_DIR = "./results"

# Dataset Options
flags.DEFINE_string('dsprites_path',
                    '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
                    'path to dsprites characters')
flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'whether cutoff should always in image')
flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
flags.DEFINE_bool('dsize_only', False, 'fix all factors except for size of objects')
flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
flags.DEFINE_bool('dsprites_restrict', False, 'fix all factors except for rotation of objects')
flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')


# Data augmentation options
flags.DEFINE_bool('cutout_inside', False, 'whether cutoff should always in image')
flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout')
flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout')
flags.DEFINE_bool('cutout', False, 'whether to add cutout regularizer to data')


def cutout(mask_color=(0, 0, 0)):
    mask_size_half = FLAGS.cutout_mask_size // 2
    offset = 1 if FLAGS.cutout_mask_size % 2 == 0 else 0

    def _cutout(image):
        image = np.asarray(image).copy()

        if np.random.random() > FLAGS.cutout_prob:
            return image

        h, w = image.shape[:2]

        if FLAGS.cutout_inside:
            cxmin, cxmax = mask_size_half, w + offset - mask_size_half
            cymin, cymax = mask_size_half, h + offset - mask_size_half
        else:
            cxmin, cxmax = 0, w + offset
            cymin, cymax = 0, h + offset

        cx = np.random.randint(cxmin, cxmax)
        cy = np.random.randint(cymin, cymax)
        xmin = cx - mask_size_half
        ymin = cy - mask_size_half
        xmax = xmin + FLAGS.cutout_mask_size
        ymax = ymin + FLAGS.cutout_mask_size
        xmin = max(0, xmin)
        ymin = max(0, ymin)
        xmax = min(w, xmax)
        ymax = min(h, ymax)
        image[:, ymin:ymax, xmin:xmax] = np.array(mask_color)[:, None, None]
        return image

    return _cutout
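# note (added in this mirror): cutout() returns a transform that masks out a
# random FLAGS.cutout_mask_size square of the image (the "cutout" augmentation);
# it is only appended to the torchvision transform pipelines below when the
# --cutout flag is set.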
class TFImagenetLoader(Dataset):

    def __init__(self, split, batchsize, idx, num_workers, rescale=1):
        IMAGENET_NUM_TRAIN_IMAGES = 1281167
        IMAGENET_NUM_VAL_IMAGES = 50000

        self.rescale = rescale

        if split == "train":
            im_length = IMAGENET_NUM_TRAIN_IMAGES
            records_to_skip = im_length * idx // num_workers
            records_to_read = im_length * (idx + 1) // num_workers - records_to_skip
        else:
            im_length = IMAGENET_NUM_VAL_IMAGES

        self.curr_sample = 0

        index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
        with open(index_path) as f:
            metadata = json.load(f)
            counts = metadata['record_counts']

        if split == 'train':
            file_names = list(sorted([x for x in counts.keys() if x.startswith('train')]))

            result_records_to_skip = None
            files = []
            for filename in file_names:
                records_in_file = counts[filename]
                if records_to_skip >= records_in_file:
                    records_to_skip -= records_in_file
                    continue
                elif records_to_read > 0:
                    if result_records_to_skip is None:
                        # Record the number to skip in the first file
                        result_records_to_skip = records_to_skip
                    files.append(filename)
                    records_to_read -= (records_in_file - records_to_skip)
                    records_to_skip = 0
                else:
                    break
        else:
            files = list(sorted([x for x in counts.keys() if x.startswith('validation')]))

        files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
        preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess

        ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
        ds = ds.apply(tf.data.TFRecordDataset)
        ds = ds.take(im_length)
        ds = ds.prefetch(buffer_size=FLAGS.batch_size)
        ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
        ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4))
        ds = ds.prefetch(buffer_size=2)

        ds_iterator = ds.make_initializable_iterator()
        labels, images = ds_iterator.get_next()
        self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
        self.labels = labels

        config = tf.ConfigProto(device_count={'GPU': 0})
        sess = tf.Session(config=config)
        sess.run(ds_iterator.initializer)

        self.im_length = im_length // batchsize

        self.sess = sess

    def __next__(self):
        self.curr_sample += 1

        sess = self.sess

        im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
        label, im = sess.run([self.labels, self.images])
        im = im * self.rescale
        label = np.eye(1000)[label.squeeze() - 1]
        im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label)
        return im_corrupt, im, label

    def __iter__(self):
        return self

    def __len__(self):
        return self.im_length


class CelebA(Dataset):

    def __init__(self):
        self.path = "/root/data/img_align_celeba"
        self.ims = os.listdir(self.path)
        self.ims = [osp.join(self.path, im) for im in self.ims]

    def __len__(self):
        return len(self.ims)

    def __getitem__(self, index):
        label = 1

        if FLAGS.single:
            index = 0

        path = self.ims[index]
        im = imread(path)
        im = imresize(im, (32, 32))
        image_size = 32
        im = im / 255.

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0, 1, size=(image_size, image_size, 3))

        return im_corrupt, im, label


class Cifar10(Dataset):
    def __init__(
            self,
            train=True,
            full=False,
            augment=False,
            noise=True,
            rescale=1.0):

        if augment:
            transform_list = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]

            if FLAGS.cutout:
                transform_list.append(cutout())

            transform = transforms.Compose(transform_list)
        else:
            transform = transforms.ToTensor()

        self.full = full
        self.data = CIFAR10(
            ROOT_DIR,
            transform=transform,
            train=train,
            download=True)
        self.test_data = CIFAR10(
            ROOT_DIR,
            transform=transform,
            train=False,
            download=True)
        self.one_hot_map = np.eye(10)
        self.noise = noise
        self.rescale = rescale

    def __len__(self):

        if self.full:
            return len(self.data) + len(self.test_data)
        else:
            return len(self.data)

    def __getitem__(self, index):
        if not FLAGS.single:
            if self.full:
                if index >= len(self.data):
                    im, label = self.test_data[index - len(self.data)]
                else:
                    im, label = self.data[index]
            else:
                im, label = self.data[index]
        else:
            im, label = self.data[0]

        im = np.transpose(im, (1, 2, 0)).numpy()
        image_size = 32
        label = self.one_hot_map[label]

        im = im * 255 / 256

        if self.noise:
            im = im * self.rescale + \
                np.random.uniform(0, self.rescale * 1 / 256., im.shape)

        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, self.rescale, (image_size, image_size, 3))

        return im_corrupt, im, label
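# note (added in this mirror): most dataset classes in this file yield a
# (im_corrupt, im, label) triple -- im is the real image scaled to [0, rescale],
# im_corrupt is the initialization for the negative-sample langevin chain
# (uniform noise or a noised copy of the image, depending on --datasource), and
# label is a one-hot vector (CelebA uses a dummy label and Textures returns the
# clean image twice).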
class Cifar100(Dataset):
    def __init__(self, train=True, augment=False):

        if augment:
            transform_list = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]

            if FLAGS.cutout:
                transform_list.append(cutout())

            transform = transforms.Compose(transform_list)
        else:
            transform = transforms.ToTensor()

        self.data = CIFAR100(
            "/root/cifar100",
            transform=transform,
            train=train,
            download=True)
        self.one_hot_map = np.eye(100)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        if not FLAGS.single:
            im, label = self.data[index]
        else:
            im, label = self.data[0]

        im = np.transpose(im, (1, 2, 0)).numpy()
        image_size = 32
        label = self.one_hot_map[label]
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))

        return im_corrupt, im, label


class Svhn(Dataset):
    def __init__(self, train=True, augment=False):

        transform = transforms.ToTensor()

        self.data = SVHN("/root/svhn", transform=transform, download=True)
        self.one_hot_map = np.eye(10)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        if not FLAGS.single:
            im, label = self.data[index]
        else:
            im, label = self.data[0]

        im = np.transpose(im, (1, 2, 0)).numpy()
        image_size = 32
        label = self.one_hot_map[label]
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))

        return im_corrupt, im, label


class Mnist(Dataset):
    def __init__(self, train=True, rescale=1.0):
        self.data = MNIST(
            "/root/mnist",
            transform=transforms.ToTensor(),
            download=True, train=train)
        self.labels = np.eye(10)
        self.rescale = rescale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        im, label = self.data[index]
        label = self.labels[label]
        im = im.squeeze()
        # im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28))
        # im = im.numpy() / 2 + 0.2
        im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28))
        im = im * self.rescale
        image_size = 28

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(0, self.rescale, (28, 28))

        return im_corrupt, im, label


class DSprites(Dataset):
    def __init__(
            self,
            cond_size=False,
            cond_shape=False,
            cond_pos=False,
            cond_rot=False):
        dat = np.load(FLAGS.dsprites_path)

        if FLAGS.dshape_only:
            l = dat['latents_values']
            mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 /
                                           31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
            self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
            self.label = np.tile(dat['latents_values'][mask], (10000, 1))
            self.label = self.label[:, 1:2]
        elif FLAGS.dpos_only:
            l = dat['latents_values']
            # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
            mask = (l[:, 1] == 1) & (
                l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
            self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
            self.label = np.tile(dat['latents_values'][mask], (100, 1))
            self.label = self.label[:, 4:] + 0.5
        elif FLAGS.dsize_only:
            l = dat['latents_values']
            # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
            mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 /
                                                   31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
            self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
            self.label = np.tile(dat['latents_values'][mask], (10000, 1))
            self.label = (self.label[:, 2:3])
        elif FLAGS.drot_only:
            l = dat['latents_values']
            mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 /
                                       31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
            self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
            self.label = np.tile(dat['latents_values'][mask], (100, 1))
            self.label = (self.label[:, 3:4])
            self.label = np.concatenate(
                [np.cos(self.label), np.sin(self.label)], axis=1)
        elif FLAGS.dsprites_restrict:
            l = dat['latents_values']
            mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39)

            self.data = dat['imgs'][mask]
            self.label = dat['latents_values'][mask]
        else:
            self.data = dat['imgs']
            self.label = dat['latents_values']

            if cond_size:
                self.label = self.label[:, 2:3]
            elif cond_shape:
                self.label = self.label[:, 1:2]
            elif cond_pos:
                self.label = self.label[:, 4:]
            elif cond_rot:
                self.label = self.label[:, 3:4]
                self.label = np.concatenate(
                    [np.cos(self.label), np.sin(self.label)], axis=1)
            else:
                self.label = self.label[:, 1:2]

        self.identity = np.eye(3)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        im = self.data[index]
        image_size = 64

        if not (
                FLAGS.dpos_only or FLAGS.dsize_only) and (
                not FLAGS.cond_size) and (
                not FLAGS.cond_pos) and (
                not FLAGS.cond_rot) and (
                not FLAGS.drot_only):
            label = self.identity[self.label[index].astype(
                np.int32) - 1].squeeze()
        else:
            label = self.label[index]

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
        elif FLAGS.datasource == 'random':
            im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)

        return im_corrupt, im, label


class Imagenet(Dataset):
    def __init__(self, train=True, augment=False):

        if train:
            for i in range(1, 11):
                f = pickle.load(
                    open(
                        osp.join(
                            FLAGS.imagenet_path,
                            'train_data_batch_{}'.format(i)),
                        'rb'))
                if i == 1:
                    labels = f['labels']
                    data = f['data']
                else:
                    labels.extend(f['labels'])
                    data = np.vstack((data, f['data']))
        else:
            f = pickle.load(
                open(
                    osp.join(
                        FLAGS.imagenet_path,
                        'val_data'),
                    'rb'))
            labels = f['labels']
            data = f['data']

        self.labels = labels
        self.data = data
        self.one_hot_map = np.eye(1000)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        if not FLAGS.single:
            im, label = self.data[index], self.labels[index]
        else:
            im, label = self.data[0], self.labels[0]

        label -= 1

        im = im.reshape((3, 32, 32)) / 255
        im = im.transpose((1, 2, 0))
        image_size = 32
        label = self.one_hot_map[label]
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))

        return im_corrupt, im, label


class Textures(Dataset):
    def __init__(self, train=True, augment=False):
        self.dataset = ImageFolder("/mnt/nfs/yilundu/data/dtd/images")

    def __len__(self):
        return 2 * len(self.dataset)

    def __getitem__(self, index):
        idx = index % (len(self.dataset))
        im, label = self.dataset[idx]

        im = np.array(im)[:32, :32] / 255
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)

        return im, im, label
698 EBMs/ebm_combine.py (Normal file)

@ -0,0 +1,698 @@
import tensorflow as tf
import math
from tqdm import tqdm
from hmc import hmc
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader, Dataset
from models import DspritesNet
from utils import optimistic_restore, ReplayBuffer
import os.path as osp
import numpy as np
from rl_algs.logger import TensorBoardOutputFormat
from scipy.misc import imsave
import os
from custom_adam import AdamOptimizer

flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
flags.DEFINE_integer('data_workers', 4, 'Number of workers to do things')
flags.DEFINE_string('logdir', 'cachedir', 'directory for logging')
flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
flags.DEFINE_float('step_lr', 500, 'size of gradient descent size')
flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters')
flags.DEFINE_bool('cclass', True, 'not cclass')
flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results')
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.')
flags.DEFINE_bool('joint_shape', False, 'whether to use pos_size or pos_shape')
flags.DEFINE_bool('joint_rot', False, 'whether to use pos_size or pos_shape')

# Conditions on which models to use
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale')

flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of experiments')
flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of experiments')
flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of experiments')
flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of experiments')
flags.DEFINE_integer('resume_size', 169000, 'First iteration to resume')
flags.DEFINE_integer('resume_shape', 477000, 'Second iteration to resume')
flags.DEFINE_integer('resume_pos', 8000, 'Second iteration to resume')
flags.DEFINE_integer('resume_rot', 690000, 'Second iteration to resume')
flags.DEFINE_integer('break_steps', 300, 'steps to break')

# Whether to train for gentest
flags.DEFINE_bool('train', False, 'whether to train on generalization into multiple different predictions')

FLAGS = flags.FLAGS


class DSpritesGen(Dataset):
    def __init__(self, data, latents, frac=0.0):

        l = latents

        if FLAGS.joint_shape:
            mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
        elif FLAGS.joint_rot:
            mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
        else:
            mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1)

        mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)

        data_pos = data[mask_pos]
        l_pos = l[mask_pos]

        data_size = data[mask_size]
        l_size = l[mask_size]

        n = data_pos.shape[0] // data_size.shape[0]

        data_pos = np.tile(data_pos, (n, 1, 1))
        l_pos = np.tile(l_pos, (n, 1))

        self.data = np.concatenate((data_pos, data_size), axis=0)
        self.label = np.concatenate((l_pos, l_size), axis=0)

        mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39))
        data_add = data[mask_neg]
        l_add = l[mask_neg]

        perm_idx = np.random.permutation(data_add.shape[0])
        select_idx = perm_idx[:int(frac*perm_idx.shape[0])]
        data_add = data_add[select_idx]
        l_add = l_add[select_idx]

        self.data = np.concatenate((self.data, data_add), axis=0)
        self.label = np.concatenate((self.label, l_add), axis=0)

        self.identity = np.eye(3)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        im = self.data[index]
        im_corrupt = 0.5 + 0.5 * np.random.randn(64, 64)

        if FLAGS.joint_shape:
            label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
        elif FLAGS.joint_rot:
            label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])])
        else:
            label_size = self.label[index, 2:3]

        label_pos = self.label[index, 4:]

        return (im_corrupt, im, label_size, label_pos)
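# note (added in this mirror): DSpritesGen builds the training set for the
# generalization experiments below. it keeps the dsprites frames where only one
# factor varies (position, plus size, shape, or rotation depending on the
# joint_* flags), pairs those subsets up, and then mixes back in a fraction
# `frac` of the held-out factor combinations so the amount of "joint"
# supervision can be controlled.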
def labeldiscover(sess, kvs, data, latents, save_exp_dir):
|
||||
LABEL_SIZE = kvs['LABEL_SIZE']
|
||||
model_size = kvs['model_size']
|
||||
weight_size = kvs['weight_size']
|
||||
x_mod = kvs['X_NOISE']
|
||||
|
||||
label_output = LABEL_SIZE
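# infer the scale latent by gradient descent on the label itself: add a little
# Gaussian noise, step against the energy gradient w.r.t. the label, and keep
# the value inside the valid scale range [0.5, 1]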
|
||||
for i in range(FLAGS.num_steps):
|
||||
label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03)
|
||||
e_noise = model_size.forward(x_mod, weight_size, label=label_output)
|
||||
label_grad = tf.gradients(e_noise, [label_output])[0]
|
||||
# label_grad = tf.Print(label_grad, [label_grad])
|
||||
label_output = label_output - 1.0 * label_grad
|
||||
label_output = tf.clip_by_value(label_output, 0.5, 1.0)
|
||||
|
||||
diffs = []
|
||||
for i in range(30):
|
||||
s = i*FLAGS.batch_size
|
||||
d = (i+1)*FLAGS.batch_size
|
||||
data_i = data[s:d]
|
||||
latent_i = latents[s:d]
|
||||
latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))
|
||||
|
||||
feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init}
|
||||
size_pred = sess.run([label_output], feed_dict)[0]
|
||||
size_gt = latent_i[:, 2:3]
|
||||
|
||||
diffs.append(np.abs(size_pred - size_gt).mean())
|
||||
|
||||
print(np.array(diffs).mean())
|
||||
|
||||
|
||||
def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
||||
# tf.reset_default_graph()
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=5)
|
||||
LABEL = tf.placeholder(shape=(None, 5), dtype=tf.float32)
|
||||
else:
|
||||
model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
|
||||
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
|
||||
|
||||
weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))
|
||||
|
||||
X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32)
|
||||
X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
|
||||
|
||||
X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
|
||||
loss_sq = tf.reduce_mean(tf.square(X_out - X_label))
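# the baseline is a plain feed-forward generator trained with pixel-wise MSE
# against ground-truth images; it serves as the comparison point for the EBM's
# compositional generalization results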
|
||||
|
||||
optimizer = AdamOptimizer(1e-3)
|
||||
gvs = optimizer.compute_gradients(loss_sq)
|
||||
gvs = [(k, v) for (k, v) in gvs if k is not None]
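# drop (gradient, variable) pairs with a None gradient so apply_gradients only
# sees variables actually connected to the loss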
|
||||
train_op = optimizer.apply_gradients(gvs)
|
||||
|
||||
dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
|
||||
|
||||
datafull = data
|
||||
|
||||
itr = 0
|
||||
saver = tf.train.Saver()
|
||||
|
||||
vs = optimizer.variables()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
if FLAGS.train:
|
||||
for _ in range(5):
|
||||
for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
|
||||
|
||||
data_corrupt = data_corrupt.numpy()
|
||||
label_size, label_pos = label_size.numpy(), label_pos.numpy()
|
||||
|
||||
data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters)
|
||||
label_comb = np.concatenate([label_size, label_pos], axis=1)
|
||||
|
||||
feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}
|
||||
|
||||
output = [loss_sq, train_op]
|
||||
|
||||
loss, _ = sess.run(output, feed_dict=feed_dict)
|
||||
|
||||
itr += 1
|
||||
|
||||
saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))
|
||||
|
||||
saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))
|
||||
|
||||
l = latents
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
|
||||
else:
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
|
||||
|
||||
data_gen = datafull[mask_gen]
|
||||
latents_gen = latents[mask_gen]
|
||||
losses = []
|
||||
|
||||
for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
|
||||
data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters)
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
|
||||
latent_pos = latent[:, 4:6]
|
||||
latent = np.concatenate([latent_size, latent_pos], axis=1)
|
||||
feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
|
||||
else:
|
||||
feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat}
|
||||
loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
|
||||
# print(loss)
|
||||
losses.append(loss)
|
||||
|
||||
print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))
|
||||
|
||||
|
||||
data_try = data_gen[:10]
|
||||
data_init = np.random.randn(10, 2*FLAGS.num_filters)
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1]
|
||||
latent_pos = latents_gen[:10, 4:]
|
||||
else:
|
||||
latent_scale = latents_gen[:10, 2:3]
|
||||
latent_pos = latents_gen[:10, 4:]
|
||||
|
||||
latent_tot = np.concatenate([latent_scale, latent_pos], axis=1)
|
||||
|
||||
feed_dict = {X_feed: data_init, LABEL: latent_tot}
|
||||
x_output = sess.run([X_out], feed_dict=feed_dict)[0]
|
||||
x_output = np.clip(x_output, 0, 1)
|
||||
|
||||
im_name = "size_scale_combine_genbaseline.png"
|
||||
|
||||
x_output_wrap = np.ones((10, 66, 66))
|
||||
data_try_wrap = np.ones((10, 66, 66))
|
||||
|
||||
x_output_wrap[:, 1:-1, 1:-1] = x_output
|
||||
data_try_wrap[:, 1:-1, 1:-1] = data_try
|
||||
|
||||
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
|
||||
impath = osp.join(save_exp_dir, im_name)
|
||||
imsave(impath, im_output)
|
||||
print("Successfully saved images at {}".format(impath))
|
||||
|
||||
return np.mean(losses)
|
||||
|
||||
|
||||
def gentest(sess, kvs, data, latents, save_exp_dir):
|
||||
X_NOISE = kvs['X_NOISE']
|
||||
LABEL_SIZE = kvs['LABEL_SIZE']
|
||||
LABEL_SHAPE = kvs['LABEL_SHAPE']
|
||||
LABEL_POS = kvs['LABEL_POS']
|
||||
LABEL_ROT = kvs['LABEL_ROT']
|
||||
model_size = kvs['model_size']
|
||||
model_shape = kvs['model_shape']
|
||||
model_pos = kvs['model_pos']
|
||||
model_rot = kvs['model_rot']
|
||||
weight_size = kvs['weight_size']
|
||||
weight_shape = kvs['weight_shape']
|
||||
weight_pos = kvs['weight_pos']
|
||||
weight_rot = kvs['weight_rot']
|
||||
X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
|
||||
|
||||
datafull = data
|
||||
# Test combination of generalization where we use slices of both training
|
||||
x_final = X_NOISE
|
||||
x_mod_size = X_NOISE
|
||||
x_mod_pos = X_NOISE
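# joint Langevin sampling: each iteration takes a noisy gradient step on the
# position model's energy, then one on the second factor's model (shape,
# rotation, or size), clipping the image back to [0, 1] after every update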
|
||||
|
||||
for i in range(FLAGS.num_steps):
|
||||
|
||||
# use cond_pos
|
||||
|
||||
energies = []
|
||||
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
|
||||
e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)
|
||||
|
||||
# energies.append(e_noise)
|
||||
x_grad = tf.gradients(e_noise, [x_final])[0]
|
||||
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
|
||||
x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
|
||||
x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
# use cond_shape
|
||||
e_noise = model_shape.forward(x_mod_pos, weight_shape, label=LABEL_SHAPE)
|
||||
elif FLAGS.joint_rot:
|
||||
e_noise = model_rot.forward(x_mod_pos, weight_rot, label=LABEL_ROT)
|
||||
else:
|
||||
# use cond_size
|
||||
e_noise = model_size.forward(x_mod_pos, weight_size, label=LABEL_SIZE)
|
||||
|
||||
# energies.append(e_noise)
|
||||
# energy_stack = tf.concat(energies, axis=1)
|
||||
# energy_stack = tf.reduce_logsumexp(-1*energy_stack, axis=1)
|
||||
# energy_stack = tf.reduce_sum(energy_stack, axis=1)
|
||||
|
||||
x_grad = tf.gradients(e_noise, [x_mod_pos])[0]
|
||||
x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
|
||||
x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)
|
||||
|
||||
# for x_mod_size
|
||||
# use cond_size
|
||||
# e_noise = model_size.forward(x_mod_size, weight_size, label=LABEL_SIZE)
|
||||
# x_grad = tf.gradients(e_noise, [x_mod_size])[0]
|
||||
# x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
|
||||
# x_mod_size = x_mod_size - FLAGS.step_lr * x_grad
|
||||
# x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)
|
||||
|
||||
# # use cond_pos
|
||||
# e_noise = model_pos.forward(x_mod_size, weight_pos, label=LABEL_POS)
|
||||
# x_grad = tf.gradients(e_noise, [x_mod_size])[0]
|
||||
# x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
|
||||
# x_mod_size = x_mod_size - FLAGS.step_lr * tf.stop_gradient(x_grad)
|
||||
# x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)
|
||||
|
||||
x_mod = x_mod_pos
|
||||
x_final = x_mod
|
||||
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
|
||||
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
|
||||
|
||||
energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
|
||||
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
|
||||
energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
|
||||
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
|
||||
elif FLAGS.joint_rot:
|
||||
loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
|
||||
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
|
||||
|
||||
energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
|
||||
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
|
||||
energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
|
||||
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
|
||||
else:
|
||||
loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
|
||||
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
|
||||
|
||||
energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
|
||||
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
|
||||
energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
|
||||
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
|
||||
|
||||
energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
|
||||
coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
|
||||
norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
|
||||
neg_loss = coeff * (-1*energy_neg) / norm_constant
|
||||
|
||||
loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
|
||||
loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))
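# combined objective: the contrastive maximum-likelihood term (positive minus
# negative energies), a sampler loss evaluated on the final samples with the
# model weights stopped, and an L2 penalty that keeps energy magnitudes bounded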
|
||||
|
||||
optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
|
||||
gvs = optimizer.compute_gradients(loss_total)
|
||||
gvs = [(k, v) for (k, v) in gvs if k is not None]
|
||||
train_op = optimizer.apply_gradients(gvs)
|
||||
|
||||
vs = optimizer.variables()
|
||||
sess.run(tf.variables_initializer(vs))
|
||||
|
||||
dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
|
||||
|
||||
x_off = tf.reduce_mean(tf.square(x_mod - X))
|
||||
|
||||
itr = 0
|
||||
saver = tf.train.Saver()
|
||||
x_mod = None
|
||||
|
||||
|
||||
if FLAGS.train:
|
||||
replay_buffer = ReplayBuffer(10000)
|
||||
for _ in range(1):
|
||||
|
||||
|
||||
for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
|
||||
data_corrupt = data_corrupt.numpy()[:, :, :]
|
||||
data = data.numpy()[:, :, :]
|
||||
|
||||
if x_mod is not None:
|
||||
replay_buffer.add(x_mod)
|
||||
replay_batch = replay_buffer.sample(FLAGS.batch_size)
|
||||
replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
|
||||
data_corrupt[replay_mask] = replay_batch[replay_mask]
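# roughly 5% of each batch is re-seeded from the replay buffer of past samples,
# so some sampling chains persist across training iterations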
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
|
||||
elif FLAGS.joint_rot:
|
||||
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
|
||||
else:
|
||||
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}
|
||||
|
||||
_, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
|
||||
itr += 1
|
||||
|
||||
if itr % 10 == 0:
|
||||
print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr))
|
||||
|
||||
if itr == FLAGS.break_steps:
|
||||
break
|
||||
|
||||
|
||||
saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))
|
||||
|
||||
saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))
|
||||
|
||||
l = latents
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
|
||||
elif FLAGS.joint_rot:
|
||||
mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
|
||||
else:
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
|
||||
|
||||
data_gen = datafull[mask_gen]
|
||||
latents_gen = latents[mask_gen]
|
||||
|
||||
losses = []
|
||||
|
||||
for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
|
||||
x = 0.5 + np.random.randn(*dat.shape)
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
|
||||
elif FLAGS.joint_rot:
|
||||
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
|
||||
else:
|
||||
feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
|
||||
|
||||
for i in range(2):
|
||||
x = sess.run([x_final], feed_dict=feed_dict)[0]
|
||||
feed_dict[X_NOISE] = x
|
||||
|
||||
loss = sess.run([x_off], feed_dict=feed_dict)[0]
|
||||
losses.append(loss)
|
||||
|
||||
print("Mean MSE loss of {} ".format(np.mean(losses)))
|
||||
|
||||
data_try = data_gen[:10]
|
||||
data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
|
||||
latent_scale = latents_gen[:10, 2:3]
|
||||
latent_pos = latents_gen[:10, 4:]
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos}
|
||||
elif FLAGS.joint_rot:
|
||||
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1), LABEL_POS: latent[:10, 4:], X_NOISE: data_init}
|
||||
else:
|
||||
feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}
|
||||
|
||||
x_output = sess.run([x_final], feed_dict=feed_dict)[0]
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
im_name = "size_shape_combine_gentest.png"
|
||||
else:
|
||||
im_name = "size_scale_combine_gentest.png"
|
||||
|
||||
x_output_wrap = np.ones((10, 66, 66))
|
||||
data_try_wrap = np.ones((10, 66, 66))
|
||||
|
||||
x_output_wrap[:, 1:-1, 1:-1] = x_output
|
||||
data_try_wrap[:, 1:-1, 1:-1] = data_try
|
||||
|
||||
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
|
||||
impath = osp.join(save_exp_dir, im_name)
|
||||
imsave(impath, im_output)
|
||||
print("Successfully saved images at {}".format(impath))
|
||||
|
||||
|
||||
|
||||
def conceptcombine(sess, kvs, data, latents, save_exp_dir):
|
||||
X_NOISE = kvs['X_NOISE']
|
||||
LABEL_SIZE = kvs['LABEL_SIZE']
|
||||
LABEL_SHAPE = kvs['LABEL_SHAPE']
|
||||
LABEL_POS = kvs['LABEL_POS']
|
||||
LABEL_ROT = kvs['LABEL_ROT']
|
||||
model_size = kvs['model_size']
|
||||
model_shape = kvs['model_shape']
|
||||
model_pos = kvs['model_pos']
|
||||
model_rot = kvs['model_rot']
|
||||
weight_size = kvs['weight_size']
|
||||
weight_shape = kvs['weight_shape']
|
||||
weight_pos = kvs['weight_pos']
|
||||
weight_rot = kvs['weight_rot']
|
||||
|
||||
x_mod = X_NOISE
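# concept combination: within each step, take a Langevin update on every enabled
# conditional model in turn (scale, shape, position, rotation); chaining the
# gradients this way roughly corresponds to sampling from the product of the
# individual conditional distributions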
|
||||
for i in range(FLAGS.num_steps):
|
||||
|
||||
if FLAGS.cond_scale:
|
||||
e_noise = model_size.forward(x_mod, weight_size, label=LABEL_SIZE)
|
||||
x_grad = tf.gradients(e_noise, [x_mod])[0]
|
||||
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
|
||||
x_mod = x_mod - FLAGS.step_lr * x_grad
|
||||
x_mod = tf.clip_by_value(x_mod, 0, 1)
|
||||
|
||||
if FLAGS.cond_shape:
|
||||
e_noise = model_shape.forward(x_mod, weight_shape, label=LABEL_SHAPE)
|
||||
x_grad = tf.gradients(e_noise, [x_mod])[0]
|
||||
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
|
||||
x_mod = x_mod - FLAGS.step_lr * x_grad
|
||||
x_mod = tf.clip_by_value(x_mod, 0, 1)
|
||||
|
||||
if FLAGS.cond_pos:
|
||||
e_noise = model_pos.forward(x_mod, weight_pos, label=LABEL_POS)
|
||||
x_grad = tf.gradients(e_noise, [x_mod])[0]
|
||||
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
|
||||
x_mod = x_mod - FLAGS.step_lr * x_grad
|
||||
x_mod = tf.clip_by_value(x_mod, 0, 1)
|
||||
|
||||
if FLAGS.cond_rot:
|
||||
e_noise = model_rot.forward(x_mod, weight_rot, label=LABEL_ROT)
|
||||
x_grad = tf.gradients(e_noise, [x_mod])[0]
|
||||
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
|
||||
x_mod = x_mod - FLAGS.step_lr * x_grad
|
||||
x_mod = tf.clip_by_value(x_mod, 0, 1)
|
||||
|
||||
print("Finished constructing loop {}".format(i))
|
||||
|
||||
x_final = x_mod
|
||||
|
||||
data_try = data[:10]
|
||||
data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
|
||||
label_scale = latents[:10, 2:3]
|
||||
label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)]
|
||||
label_rot = latents[:10, 3:4]
|
||||
label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
|
||||
label_pos = latents[:10, 4:]
|
||||
|
||||
feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos,
|
||||
LABEL_ROT: label_rot}
|
||||
x_out = sess.run([x_final], feed_dict)[0]
|
||||
|
||||
im_name = "im"
|
||||
|
||||
if FLAGS.cond_scale:
|
||||
im_name += "_condscale"
|
||||
|
||||
if FLAGS.cond_shape:
|
||||
im_name += "_condshape"
|
||||
|
||||
if FLAGS.cond_pos:
|
||||
im_name += "_condpos"
|
||||
|
||||
if FLAGS.cond_rot:
|
||||
im_name += "_condrot"
|
||||
|
||||
im_name += ".png"
|
||||
|
||||
x_out_pad, data_try_pad = np.ones((10, 66, 66)), np.ones((10, 66, 66))
|
||||
x_out_pad[:, 1:-1, 1:-1] = x_out
|
||||
data_try_pad[:, 1:-1, 1:-1] = data_try
|
||||
|
||||
im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2)
|
||||
impath = osp.join(save_exp_dir, im_name)
|
||||
imsave(impath, im_output)
|
||||
print("Successfully saved images at {}".format(impath))
|
||||
|
||||
def main():
|
||||
data = np.load(FLAGS.dsprites_path)['imgs']
|
||||
l = latents = np.load(FLAGS.dsprites_path)['latents_values']
|
||||
|
||||
np.random.seed(1)
|
||||
idx = np.random.permutation(data.shape[0])
|
||||
|
||||
data = data[idx]
|
||||
latents = latents[idx]
|
||||
|
||||
config = tf.ConfigProto()
|
||||
sess = tf.Session(config=config)
|
||||
|
||||
# Model 1 will be conditioned on size
|
||||
model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
|
||||
weight_size = model_size.construct_weights('context_0')
|
||||
|
||||
# Model 2 will be conditioned on shape
|
||||
model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
|
||||
weight_shape = model_shape.construct_weights('context_1')
|
||||
|
||||
# Model 3 will be conditioned on position
|
||||
model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
|
||||
weight_pos = model_pos.construct_weights('context_2')
|
||||
|
||||
# Model 4 will be conditioned on rotation
|
||||
model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
|
||||
weight_rot = model_rot.construct_weights('context_3')
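# four independently trained EBMs, one per latent factor, each in its own
# variable scope; they are restored from separate checkpoints below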
|
||||
|
||||
sess.run(tf.global_variables_initializer())
|
||||
save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size))
|
||||
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0))
|
||||
v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list}
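# each checkpoint stores its variables under the scope 'context_0', so the
# variables of scope context_{i} are remapped to context_0 names before restore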
|
||||
|
||||
if FLAGS.cond_scale:
|
||||
saver = tf.train.Saver(v_map)
|
||||
saver.restore(sess, save_path_size)
|
||||
|
||||
save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape))
|
||||
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
|
||||
v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
|
||||
|
||||
if FLAGS.cond_shape:
|
||||
saver = tf.train.Saver(v_map)
|
||||
saver.restore(sess, save_path_shape)
|
||||
|
||||
|
||||
save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos))
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
|
||||
v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
|
||||
saver = tf.train.Saver(v_map)
|
||||
|
||||
if FLAGS.cond_pos:
|
||||
saver.restore(sess, save_path_pos)
|
||||
|
||||
|
||||
save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot))
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3))
|
||||
v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list}
|
||||
saver = tf.train.Saver(v_map)
|
||||
|
||||
if FLAGS.cond_rot:
|
||||
saver.restore(sess, save_path_rot)
|
||||
|
||||
X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
|
||||
LABEL_SIZE = tf.placeholder(shape=(None, 1), dtype=tf.float32)
|
||||
LABEL_SHAPE = tf.placeholder(shape=(None, 3), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
|
||||
x_mod = X_NOISE
|
||||
|
||||
kvs = {}
|
||||
kvs['X_NOISE'] = X_NOISE
|
||||
kvs['LABEL_SIZE'] = LABEL_SIZE
|
||||
kvs['LABEL_SHAPE'] = LABEL_SHAPE
|
||||
kvs['LABEL_POS'] = LABEL_POS
|
||||
kvs['LABEL_ROT'] = LABEL_ROT
|
||||
kvs['model_size'] = model_size
|
||||
kvs['model_shape'] = model_shape
|
||||
kvs['model_pos'] = model_pos
|
||||
kvs['model_rot'] = model_rot
|
||||
kvs['weight_size'] = weight_size
|
||||
kvs['weight_shape'] = weight_shape
|
||||
kvs['weight_pos'] = weight_pos
|
||||
kvs['weight_rot'] = weight_rot
|
||||
|
||||
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape))
|
||||
if not osp.exists(save_exp_dir):
|
||||
os.makedirs(save_exp_dir)
|
||||
|
||||
|
||||
if FLAGS.task == 'conceptcombine':
|
||||
conceptcombine(sess, kvs, data, latents, save_exp_dir)
|
||||
elif FLAGS.task == 'labeldiscover':
|
||||
labeldiscover(sess, kvs, data, latents, save_exp_dir)
|
||||
elif FLAGS.task == 'gentest':
|
||||
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos))
|
||||
if not osp.exists(save_exp_dir):
|
||||
os.makedirs(save_exp_dir)
|
||||
|
||||
gentest(sess, kvs, data, latents, save_exp_dir)
|
||||
elif FLAGS.task == 'genbaseline':
|
||||
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos))
|
||||
if not osp.exists(save_exp_dir):
|
||||
os.makedirs(save_exp_dir)
|
||||
|
||||
if FLAGS.plot_curve:
|
||||
mse_losses = []
|
||||
for frac in [i/10 for i in range(11)]:
|
||||
mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac)
|
||||
mse_losses.append(mse_loss)
|
||||
np.save("mse_baseline_comb.npy", mse_losses)
|
||||
else:
|
||||
genbaseline(sess, kvs, data, latents, save_exp_dir)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
981
EBMs/ebm_sandbox.py
Normal file
@ -0,0 +1,981 @@
import tensorflow as tf
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from tensorflow.python.platform import flags
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, DspritesNet
|
||||
from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, DSprites
|
||||
from utils import optimistic_restore, set_seed
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
from baselines.logger import TensorBoardOutputFormat
|
||||
from scipy.misc import imsave
|
||||
import os
|
||||
import sklearn.metrics as sk
|
||||
from baselines.common.tf_util import initialize
|
||||
from scipy.linalg import eig
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# set_seed(1)
|
||||
|
||||
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'omniglot or imagenet or omniglotfull or cifar10 or mnist or dsprites')
|
||||
flags.DEFINE_string('logdir', 'sandbox_cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_string('task', 'label', 'label: use conditional energy-based models for classification, '
'anticorrupt: restore salt and pepper noise, '
'boxcorrupt: restore an empty portion of the image, '
'crossclass: change images from one class to another, '
'cycleclass: view image change across a label, '
'nearestneighbor: return the nearest images in the test set, '
'latent: traverse the latent space of an EBM through eigenvectors of the hessian (dsprites only), '
'or mixenergy: evaluate out-of-distribution generalization compared to other datasets')
|
||||
flags.DEFINE_bool('hessian', True, 'Whether to use the hessian or the Jacobian for latent traversals')
|
||||
flags.DEFINE_string('exp', 'default', 'name of experiments')
|
||||
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
|
||||
flags.DEFINE_integer('batch_size', 32, 'Size of inputs')
|
||||
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
|
||||
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
|
||||
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
|
||||
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
|
||||
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
|
||||
flags.DEFINE_bool('train', True, 'Whether to train or test network')
|
||||
flags.DEFINE_bool('single', False, 'whether to use one sample to debug')
|
||||
flags.DEFINE_bool('cclass', True, 'whether to use a conditional model (required for task label)')
|
||||
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
|
||||
flags.DEFINE_float('step_lr', 10.0, 'step size for updates on label')
|
||||
flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images')
|
||||
flags.DEFINE_bool('large_model', False, 'Whether to use a large model')
|
||||
flags.DEFINE_bool('larger_model', False, 'Whether to use a larger model')
|
||||
flags.DEFINE_bool('wider_model', False, 'Whether to use a wider model')
|
||||
flags.DEFINE_bool('svhn', False, 'Whether to test on SVHN')
|
||||
|
||||
# Conditions for mixenergy (outlier detection)
|
||||
flags.DEFINE_bool('svhnmix', False, 'Whether to test mix on SVHN')
|
||||
flags.DEFINE_bool('cifar100mix', False, 'Whether to test mix on CIFAR100')
|
||||
flags.DEFINE_bool('texturemix', False, 'Whether to test mix on Textures dataset')
|
||||
flags.DEFINE_bool('randommix', False, 'Whether to test mix on random dataset')
|
||||
|
||||
# Conditions for label task (adversarial classification)
|
||||
flags.DEFINE_integer('lival', 8, 'Value of the l_infinity constraint')
|
||||
flags.DEFINE_integer('l2val', 40, 'Value of constraint for l2')
|
||||
flags.DEFINE_integer('pgd', 0, 'number of steps project gradient descent to run')
|
||||
flags.DEFINE_integer('lnorm', -1, 'linfinity is -1, l2 norm is 2')
|
||||
flags.DEFINE_bool('labelgrid', False, 'Make a grid of labels')
|
||||
|
||||
# Conditions on which models to use
|
||||
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
|
||||
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
|
||||
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
|
||||
flags.DEFINE_bool('cond_size', True, 'whether to condition on scale')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
def rescale_im(im):
|
||||
im = np.clip(im, 0, 1)
|
||||
return np.round(im * 255).astype(np.uint8)
|
||||
|
||||
def label(dataloader, test_dataloader, target_vars, sess, l1val=8, l2val=40):
|
||||
X = target_vars['X']
|
||||
Y = target_vars['Y']
|
||||
Y_GT = target_vars['Y_GT']
|
||||
accuracy = target_vars['accuracy']
|
||||
train_op = target_vars['train_op']
|
||||
l1_norm = target_vars['l1_norm']
|
||||
l2_norm = target_vars['l2_norm']
|
||||
|
||||
label_init = np.random.uniform(0, 1, (FLAGS.batch_size, 10))
|
||||
label_init = label_init / label_init.sum(axis=1, keepdims=True)
|
||||
|
||||
label_init = np.tile(np.eye(10)[None, :, :], (FLAGS.batch_size, 1, 1))
|
||||
label_init = np.reshape(label_init, (-1, 10))
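# every image is paired with all 10 one-hot labels; classification then picks
# the label whose conditional energy is lowest (i.e. the largest logit)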
|
||||
|
||||
for i in range(1):
|
||||
emp_accuracies = []
|
||||
|
||||
for data_corrupt, data, label_gt in tqdm(test_dataloader):
|
||||
feed_dict = {X: data, Y_GT: label_gt, Y: label_init, l1_norm: l1val, l2_norm: l2val}
|
||||
emp_accuracy = sess.run([accuracy], feed_dict)
|
||||
emp_accuracies.append(emp_accuracy)
|
||||
print(np.array(emp_accuracies).mean())
|
||||
|
||||
print("Received total accuracy of {} for li of {} and l2 of {}".format(np.array(emp_accuracies).mean(), l1val, l2val))
|
||||
|
||||
return np.array(emp_accuracies).mean()
|
||||
|
||||
|
||||
def labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=8, l2val=40):
|
||||
X = target_vars['X']
|
||||
Y = target_vars['Y']
|
||||
Y_GT = target_vars['Y_GT']
|
||||
accuracy = target_vars['accuracy']
|
||||
train_op = target_vars['train_op']
|
||||
l1_norm = target_vars['l1_norm']
|
||||
l2_norm = target_vars['l2_norm']
|
||||
|
||||
label_init = np.random.uniform(0, 1, (FLAGS.batch_size, 10))
|
||||
label_init = label_init / label_init.sum(axis=1, keepdims=True)
|
||||
|
||||
label_init = np.tile(np.eye(10)[None, :, :], (FLAGS.batch_size, 1, 1))
|
||||
label_init = np.reshape(label_init, (-1, 10))
|
||||
|
||||
itr = 0
|
||||
|
||||
if FLAGS.train:
|
||||
for i in range(1):
|
||||
for data_corrupt, data, label_gt in tqdm(dataloader):
|
||||
feed_dict = {X: data, Y_GT: label_gt, Y: label_init}
|
||||
acc, _ = sess.run([accuracy, train_op], feed_dict)
|
||||
|
||||
itr += 1
|
||||
|
||||
if itr % 10 == 0:
|
||||
print(acc)
|
||||
|
||||
saver.save(sess, osp.join(savedir, "model_supervised"))
|
||||
|
||||
saver.restore(sess, osp.join(savedir, "model_supervised"))
|
||||
|
||||
|
||||
for i in range(1):
|
||||
emp_accuracies = []
|
||||
|
||||
for data_corrupt, data, label_gt in tqdm(test_dataloader):
|
||||
feed_dict = {X: data, Y_GT: label_gt, Y: label_init, l1_norm: l1val, l2_norm: l2val}
|
||||
emp_accuracy = sess.run([accuracy], feed_dict)
|
||||
emp_accuracies.append(emp_accuracy)
|
||||
print(np.array(emp_accuracies).mean())
|
||||
|
||||
|
||||
print("Received total accuracy of {} for li of {} and l2 of {}".format(np.array(emp_accuracies).mean(), l1val, l2val))
|
||||
|
||||
return np.array(emp_accuracies).mean()
|
||||
|
||||
|
||||
def energyeval(dataloader, test_dataloader, target_vars, sess):
|
||||
X = target_vars['X']
|
||||
Y_GT = target_vars['Y_GT']
|
||||
energy = target_vars['energy']
|
||||
energy_end = target_vars['energy_end']
|
||||
|
||||
test_energies = []
|
||||
train_energies = []
|
||||
for data_corrupt, data, label_gt in tqdm(test_dataloader):
|
||||
feed_dict = {X: data, Y_GT: label_gt}
|
||||
test_energy = sess.run([energy], feed_dict)[0]
|
||||
test_energies.extend(list(test_energy))
|
||||
|
||||
for data_corrupt, data, label_gt in tqdm(dataloader):
|
||||
feed_dict = {X: data, Y_GT: label_gt}
|
||||
train_energy = sess.run([energy], feed_dict)[0]
|
||||
train_energies.extend(list(train_energy))
|
||||
|
||||
print(len(train_energies))
|
||||
print(len(test_energies))
|
||||
|
||||
print("Train energies of {} with std {}".format(np.mean(train_energies), np.std(train_energies)))
|
||||
print("Test energies of {} with std {}".format(np.mean(test_energies), np.std(test_energies)))
|
||||
|
||||
np.save("train_ebm.npy", train_energies)
|
||||
np.save("test_ebm.npy", test_energies)
|
||||
|
||||
|
||||
def energyevalmix(dataloader, test_dataloader, target_vars, sess):
|
||||
X = target_vars['X']
|
||||
Y_GT = target_vars['Y_GT']
|
||||
energy = target_vars['energy']
|
||||
|
||||
if FLAGS.svhnmix:
|
||||
dataset = Svhn(train=False)
|
||||
test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)
|
||||
test_iter = iter(test_dataloader_val)
|
||||
elif FLAGS.cifar100mix:
|
||||
dataset = Cifar100(train=False)
|
||||
test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)
|
||||
test_iter = iter(test_dataloader_val)
|
||||
elif FLAGS.texturemix:
|
||||
dataset = Textures()
|
||||
test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)
|
||||
test_iter = iter(test_dataloader_val)
|
||||
|
||||
probs = []
|
||||
labels = []
|
||||
negs = []
|
||||
pos = []
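# out-of-distribution detection: negative energies act as in-distribution
# scores, and the AUROC below separates real test images (label 1) from the
# mixed / OOD batch assembled inside the loop (label 0)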
|
||||
for data_corrupt, data, label_gt in tqdm(test_dataloader):
|
||||
data = data.numpy()
|
||||
data_corrupt = data_corrupt.numpy()
|
||||
if FLAGS.svhnmix:
|
||||
_, data_mix, _ = test_iter.next()
|
||||
elif FLAGS.cifar100mix:
|
||||
_, data_mix, _ = test_iter.next()
|
||||
elif FLAGS.texturemix:
|
||||
_, data_mix, _ = test_iter.next()
|
||||
elif FLAGS.randommix:
|
||||
data_mix = np.random.randn(FLAGS.batch_size, 32, 32, 3) * 0.5 + 0.5
|
||||
else:
|
||||
data_idx = np.concatenate([np.arange(1, data.shape[0]), [0]])
|
||||
data_other = data[data_idx]
|
||||
data_mix = (data + data_other) / 2
|
||||
|
||||
data_mix = data_mix[:data.shape[0]]
|
||||
|
||||
if FLAGS.cclass:
|
||||
# It's unfair to take a random class
|
||||
label_gt= np.tile(np.eye(10), (data.shape[0], 1, 1))
|
||||
label_gt = label_gt.reshape(data.shape[0] * 10, 10)
|
||||
data_mix = np.tile(data_mix[:, None, :, :, :], (1, 10, 1, 1, 1))
|
||||
data = np.tile(data[:, None, :, :, :], (1, 10, 1, 1, 1))
|
||||
|
||||
data_mix = data_mix.reshape(-1, 32, 32, 3)
|
||||
data = data.reshape(-1, 32, 32, 3)
|
||||
|
||||
|
||||
feed_dict = {X: data, Y_GT: label_gt}
|
||||
feed_dict_neg = {X: data_mix, Y_GT: label_gt}
|
||||
|
||||
pos_energy = sess.run([energy], feed_dict)[0]
|
||||
neg_energy = sess.run([energy], feed_dict_neg)[0]
|
||||
|
||||
if FLAGS.cclass:
|
||||
pos_energy = pos_energy.reshape(-1, 10).min(axis=1)
|
||||
neg_energy = neg_energy.reshape(-1, 10).min(axis=1)
|
||||
|
||||
probs.extend(list(-1*pos_energy))
|
||||
probs.extend(list(-1*neg_energy))
|
||||
pos.extend(list(-1*pos_energy))
|
||||
negs.extend(list(-1*neg_energy))
|
||||
labels.extend([1]*pos_energy.shape[0])
|
||||
labels.extend([0]*neg_energy.shape[0])
|
||||
|
||||
pos, negs = np.array(pos), np.array(negs)
|
||||
np.save("pos.npy", pos)
|
||||
np.save("neg.npy", negs)
|
||||
auroc = sk.roc_auc_score(labels, probs)
|
||||
print("Roc score of {}".format(auroc))
|
||||
|
||||
|
||||
def anticorrupt(dataloader, weights, model, target_vars, logdir, sess):
|
||||
X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
|
||||
for data_corrupt, data, label_gt in tqdm(dataloader):
|
||||
data, label_gt = data.numpy(), label_gt.numpy()
|
||||
|
||||
noise = np.random.uniform(0, 1, size=[data.shape[0], data.shape[1], data.shape[2]])
|
||||
low_mask = noise < 0.05
|
||||
high_mask = (noise > 0.05) & (noise < 0.1)
|
||||
|
||||
print(high_mask.shape)
|
||||
|
||||
data_corrupt = data.copy()
|
||||
data_corrupt[low_mask] = 0.1
|
||||
data_corrupt[high_mask] = 0.9
|
||||
data_corrupt_init = data_corrupt
|
||||
|
||||
for i in range(5):
|
||||
feed_dict = {X: data_corrupt, Y_GT: label_gt}
|
||||
data_corrupt = sess.run([X_final], feed_dict)[0]
|
||||
|
||||
data_uncorrupt = data_corrupt
|
||||
data_corrupt, data_uncorrupt, data = rescale_im(data_corrupt_init), rescale_im(data_uncorrupt), rescale_im(data)
|
||||
|
||||
panel_im = np.zeros((32*20, 32*3, 3)).astype(np.uint8)
|
||||
|
||||
for i in range(20):
|
||||
panel_im[32*i:32*i+32, :32] = data_corrupt[i]
|
||||
panel_im[32*i:32*i+32, 32:64] = data_uncorrupt[i]
|
||||
panel_im[32*i:32*i+32, 64:] = data[i]
|
||||
|
||||
imsave(osp.join(logdir, "anticorrupt.png"), panel_im)
|
||||
assert False
|
||||
|
||||
|
||||
def boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess):
|
||||
X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
|
||||
eval_im = 10000
|
||||
|
||||
data_diff = []
|
||||
for data_corrupt, data, label_gt in tqdm(dataloader):
|
||||
data, label_gt = data.numpy(), label_gt.numpy()
|
||||
data_uncorrupts = []
|
||||
|
||||
data_corrupt = data.copy()
|
||||
data_corrupt[:, 16:, :] = np.random.uniform(0, 1, (FLAGS.batch_size, 16, 32, 3))
|
||||
|
||||
data_corrupt_init = data_corrupt
|
||||
|
||||
for j in range(10):
|
||||
feed_dict = {X: data_corrupt, Y_GT: label_gt}
|
||||
data_corrupt = sess.run([X_final], feed_dict)[0]
|
||||
|
||||
val = np.mean(np.square(data_corrupt - data), axis=(1, 2, 3))
|
||||
data_diff.extend(list(val))
|
||||
|
||||
if len(data_diff) > eval_im:
|
||||
break
|
||||
|
||||
print("Mean {} and std {} for train dataloader".format(np.mean(data_diff), np.std(data_diff)))
|
||||
|
||||
np.save("data_diff_train_image.npy", data_diff)
|
||||
|
||||
data_diff = []
|
||||
|
||||
for data_corrupt, data, label_gt in tqdm(test_dataloader):
|
||||
data, label_gt = data.numpy(), label_gt.numpy()
|
||||
data_uncorrupts = []
|
||||
|
||||
data_corrupt = data.copy()
|
||||
data_corrupt[:, 16:, :] = np.random.uniform(0, 1, (FLAGS.batch_size, 16, 32, 3))
|
||||
|
||||
data_corrupt_init = data_corrupt
|
||||
|
||||
for j in range(10):
|
||||
feed_dict = {X: data_corrupt, Y_GT: label_gt}
|
||||
data_corrupt = sess.run([X_final], feed_dict)[0]
|
||||
|
||||
data_diff.extend(list(np.mean(np.square(data_corrupt - data), axis=(1, 2, 3))))
|
||||
|
||||
if len(data_diff) > eval_im:
|
||||
break
|
||||
|
||||
print("Mean {} and std {} for test dataloader".format(np.mean(data_diff), np.std(data_diff)))
|
||||
|
||||
np.save("data_diff_test_image.npy", data_diff)
|
||||
|
||||
|
||||
def crossclass(dataloader, weights, model, target_vars, logdir, sess):
|
||||
X, Y_GT, X_mods, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_mods'], target_vars['X_final']
|
||||
for data_corrupt, data, label_gt in tqdm(dataloader):
|
||||
data, label_gt = data.numpy(), label_gt.numpy()
|
||||
data_corrupt = data.copy()
|
||||
data_corrupt[1:] = data_corrupt[0:-1]
|
||||
data_corrupt[0] = data[-1]
|
||||
|
||||
data_mods = []
|
||||
data_mod = data_corrupt
|
||||
|
||||
for i in range(10):
|
||||
data_mods.append(data_mod)
|
||||
|
||||
feed_dict = {X: data_mod, Y_GT: label_gt}
|
||||
data_mod = sess.run(X_final, feed_dict)
|
||||
|
||||
|
||||
|
||||
data_corrupt, data = rescale_im(data_corrupt), rescale_im(data)
|
||||
|
||||
data_mods = [rescale_im(data_mod) for data_mod in data_mods]
|
||||
|
||||
panel_im = np.zeros((32*20, 32*(len(data_mods) + 2), 3)).astype(np.uint8)
|
||||
|
||||
for i in range(20):
|
||||
panel_im[32*i:32*i+32, :32] = data_corrupt[i]
|
||||
|
||||
for j in range(len(data_mods)):
|
||||
panel_im[32*i:32*i+32, 32*(j+1):32*(j+2)] = data_mods[j][i]
|
||||
|
||||
panel_im[32*i:32*i+32, -32:] = data[i]
|
||||
|
||||
imsave(osp.join(logdir, "crossclass.png"), panel_im)
|
||||
assert False
|
||||
|
||||
|
||||
def cycleclass(dataloader, weights, model, target_vars, logdir, sess):
|
||||
# X, Y_GT, X_final, X_targ = target_vars['X'], target_vars['Y_GT'], target_vars['X_final'], target_vars['X_targ']
|
||||
X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
|
||||
for data_corrupt, data, label_gt in tqdm(dataloader):
|
||||
data, label_gt = data.numpy(), label_gt.numpy()
|
||||
data_corrupt = data_corrupt.numpy()
|
||||
|
||||
|
||||
data_mods = []
|
||||
x_curr = data_corrupt
|
||||
x_target = np.random.uniform(0, 1, data_corrupt.shape)
|
||||
# x_target = np.tile(x_target, (1, 32, 32, 1))
|
||||
|
||||
|
||||
for i in range(20):
|
||||
feed_dict = {X: x_curr, Y_GT: label_gt}
|
||||
x_curr_new = sess.run(X_final, feed_dict)
|
||||
x_curr = x_curr_new
|
||||
data_mods.append(x_curr_new)
|
||||
|
||||
if i > 30:
|
||||
x_target = np.random.uniform(0, 1, data_corrupt.shape)
|
||||
|
||||
data_corrupt, data = rescale_im(data_corrupt), rescale_im(data)
|
||||
|
||||
data_mods = [rescale_im(data_mod) for data_mod in data_mods]
|
||||
|
||||
panel_im = np.zeros((32*100, 32*(len(data_mods) + 2), 3)).astype(np.uint8)
|
||||
|
||||
for i in range(100):
|
||||
panel_im[32*i:32*i+32, :32] = data_corrupt[i]
|
||||
|
||||
for j in range(len(data_mods)):
|
||||
panel_im[32*i:32*i+32, 32*(j+1):32*(j+2)] = data_mods[j][i]
|
||||
|
||||
panel_im[32*i:32*i+32, -32:] = data[i]
|
||||
|
||||
imsave(osp.join(logdir, "cycleclass.png"), panel_im)
|
||||
assert False
|
||||
|
||||
|
||||
def democlass(dataloader, weights, model, target_vars, logdir, sess):
|
||||
X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
|
||||
panel_im = np.zeros((5*32, 10*32, 3)).astype(np.uint8)
|
||||
for i in range(10):
|
||||
data_corrupt = np.random.uniform(0, 1, (5, 32, 32, 3))
|
||||
label_gt = np.tile(np.eye(10)[i:i+1], (5, 1))
|
||||
|
||||
feed_dict = {X: data_corrupt, Y_GT: label_gt}
|
||||
x_final = sess.run([X_final], feed_dict)[0]
|
||||
|
||||
x_final = rescale_im(x_final)
|
||||
|
||||
row = i // 2
|
||||
col = i % 2
|
||||
|
||||
start_idx = col * 32 * 5
|
||||
row_idx = row * 32
|
||||
|
||||
for j in range(5):
|
||||
panel_im[row_idx:row_idx+32, start_idx+j*32:start_idx+(j+1) * 32] = x_final[j]
|
||||
|
||||
imsave(osp.join(logdir, "democlass.png"), panel_im)
|
||||
|
||||
|
||||
def construct_finetune_label(weight, X, Y, Y_GT, model, target_vars):
|
||||
l1_norm = tf.placeholder(shape=(), dtype=tf.float32)
|
||||
l2_norm = tf.placeholder(shape=(), dtype=tf.float32)
|
||||
|
||||
def compute_logit(X, stop_grad=False, num_steps=0):
|
||||
batch_size = tf.shape(X)[0]
|
||||
X = tf.reshape(X, (batch_size, 1, 32, 32, 3))
|
||||
X = tf.reshape(tf.tile(X, (1, 10, 1, 1, 1)), (batch_size * 10, 32, 32, 3))
|
||||
Y_new = tf.reshape(Y, (batch_size*10, 10))
|
||||
|
||||
X_min = X - 8 / 255.
|
||||
X_max = X + 8 / 255.
|
||||
|
||||
for i in range(num_steps):
|
||||
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
|
||||
|
||||
energy_noise = model.forward(X, weights, label=Y, reuse=True)
|
||||
x_grad = tf.gradients(energy_noise, [X])[0]
|
||||
|
||||
|
||||
if FLAGS.proj_norm != 0.0:
|
||||
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
|
||||
|
||||
X = X - FLAGS.step_lr * x_grad
|
||||
X = tf.maximum(tf.minimum(X, X_max), X_min)
|
||||
|
||||
energy = model.forward(X, weight, label=Y_new)
|
||||
energy = -tf.reshape(energy, (batch_size, 10))
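# negated energies over the 10 candidate labels serve as classification logits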
|
||||
|
||||
if stop_grad:
|
||||
energy = tf.stop_gradient(energy)
|
||||
|
||||
return energy
|
||||
|
||||
for i in range(FLAGS.pgd):
|
||||
if FLAGS.train:
|
||||
break
|
||||
|
||||
print("Constructed loop {} of pgd attack".format(i))
|
||||
X_init = X
|
||||
if i == 0:
|
||||
X = X + tf.to_float(tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)) / 255.
|
||||
|
||||
logit = compute_logit(X)
|
||||
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logit)
|
||||
|
||||
x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255.
|
||||
X = X + 2 * x_grad
|
||||
|
||||
if FLAGS.lnorm == -1:
|
||||
X = tf.maximum(tf.minimum(X, X_max), X_min)
|
||||
elif FLAGS.lnorm == 2:
|
||||
X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255., axes=[1, 2, 3])
|
||||
|
||||
|
||||
energy = compute_logit(X, num_steps=0)
|
||||
logits = energy
|
||||
labels = tf.argmax(Y_GT, axis=1)
|
||||
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logits)
|
||||
|
||||
|
||||
optimizer = tf.train.AdamOptimizer(1e-3)
|
||||
train_op = optimizer.minimize(loss)
|
||||
accuracy = tf.contrib.metrics.accuracy(tf.argmax(logits, axis=1), labels)
|
||||
|
||||
target_vars['accuracy'] = accuracy
|
||||
target_vars['train_op'] = train_op
|
||||
target_vars['l1_norm'] = l1_norm
|
||||
target_vars['l2_norm'] = l2_norm
|
||||
|
||||
|
||||
def construct_latent(weights, X, Y_GT, model, target_vars):
|
||||
|
||||
eps = 0.001
|
||||
X_init = X[0:1]
|
||||
|
||||
def traversals(model, X, weights, Y_GT):
|
||||
if FLAGS.hessian:
|
||||
e_pos = model.forward(X, weights, label=Y_GT)
|
||||
hessian = tf.hessians(e_pos, X)
|
||||
hessian = tf.reshape(hessian, (1, 64*64, 64*64))[0]
|
||||
e, v = tf.linalg.eigh(hessian)
|
||||
else:
|
||||
latent = model.forward(X, weights, label=Y_GT, return_logit=True)
|
||||
latents = tf.split(latent, 128, axis=1)
|
||||
jacobian = [tf.gradients(latent, X)[0] for latent in latents]
|
||||
jacobian = tf.stack(jacobian, axis=1)
|
||||
jacobian = tf.reshape(jacobian, (tf.shape(jacobian)[1], tf.shape(jacobian)[1], 64*64))
|
||||
s, _, v = tf.linalg.svd(jacobian)
|
||||
|
||||
return v
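# latent traversal directions: eigenvectors of the energy's Hessian w.r.t. the
# image (or right singular vectors of the logit Jacobian) give the principal
# directions along which the image is perturbed below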
|
||||
|
||||
|
||||
var_scale = 1.0
|
||||
n = 3
|
||||
xs = []
|
||||
|
||||
v = traversals(model, X_init, weights, Y_GT)
|
||||
|
||||
for i in range(n):
|
||||
var = tf.reshape(v[:, i], (1, 64, 64))
|
||||
X_plus = X_init - var_scale * var
|
||||
X_min = X_init + var_scale * var
|
||||
|
||||
xs.extend([X_plus, X_min])
|
||||
|
||||
x_stack = tf.stack(xs, axis=0)
|
||||
|
||||
e_pos_hess_modify = model.forward(x_stack, weights, label=Y_GT)
|
||||
|
||||
for i in range(20):
|
||||
x_stack = x_stack + tf.random_normal(tf.shape(x_stack), mean=0.0, stddev=0.005)
|
||||
e_pos = model.forward(x_stack, weights, label=Y_GT)
|
||||
|
||||
x_grad = tf.gradients(e_pos, [x_stack])[0]
|
||||
x_stack = x_stack - 4*FLAGS.step_lr * x_grad
|
||||
|
||||
x_stack = tf.clip_by_value(x_stack, 0, 1)
|
||||
|
||||
x_mods = tf.split(X, 6)
|
||||
|
||||
eigs = []
|
||||
for j in range(6):
|
||||
x_mod = x_mods[j]
|
||||
v = traversals(model, x_mod, weights, Y_GT)
|
||||
|
||||
idx = j // 2
|
||||
var = tf.reshape(v[:, idx], (1, 64, 64))
|
||||
|
||||
if j % 2 == 1:
|
||||
x_mod = x_mod + var_scale * var
|
||||
eigs.append(var)
|
||||
else:
|
||||
x_mod = x_mod - var_scale * var
|
||||
eigs.append(-var)
|
||||
|
||||
x_mod = tf.clip_by_value(x_mod, 0, 1)
|
||||
x_mods[j] = x_mod
|
||||
|
||||
x_mods_stack = tf.stack(x_mods, axis=0)
|
||||
|
||||
eigs_stack = tf.stack(eigs, axis=0)
|
||||
energys = []
|
||||
|
||||
for i in range(20):
|
||||
x_mods_stack = x_mods_stack + tf.random_normal(tf.shape(x_mods_stack), mean=0.0, stddev=0.005)
|
||||
e_pos = model.forward(x_mods_stack, weights, label=Y_GT)
|
||||
|
||||
x_grad = tf.gradients(e_pos, [x_mods_stack])[0]
|
||||
x_mods_stack = x_mods_stack - 4*FLAGS.step_lr * x_grad
|
||||
# x_mods_stack = x_mods_stack + 0.1 * eigs_stack
|
||||
|
||||
x_mods_stack = tf.clip_by_value(x_mods_stack, 0, 1)
|
||||
|
||||
energys.append(e_pos)
|
||||
|
||||
x_refine = x_mods_stack
|
||||
es = tf.stack(energys, axis=0)
|
||||
|
||||
# target_vars['hessian'] = hessian
|
||||
# target_vars['e'] = e
|
||||
target_vars['v'] = v
|
||||
target_vars['x_stack'] = x_stack
|
||||
target_vars['x_refine'] = x_refine
|
||||
target_vars['es'] = es
|
||||
# target_vars['e_base'] = e_pos_base
|
||||
|
||||
|
||||
def latent(test_dataloader, weights, model, target_vars, sess):
|
||||
X = target_vars['X']
|
||||
Y_GT = target_vars['Y_GT']
|
||||
# hessian = target_vars['hessian']
|
||||
# e = target_vars['e']
|
||||
v = target_vars['v']
|
||||
x_stack = target_vars['x_stack']
|
||||
x_refine = target_vars['x_refine']
|
||||
es = target_vars['es']
|
||||
# e_pos_base = target_vars['e_base']
|
||||
# e_pos_hess_modify = target_vars['e_pos_hessian']
|
||||
|
||||
data_corrupt, data, label_gt = iter(test_dataloader).next()
|
||||
data = data.numpy()
|
||||
x_init = np.tile(data[0:1], (6, 1, 1))
|
||||
x_mod, = sess.run([x_stack], {X: data})
|
||||
# print("Value of original starting image: ", e_pos)
|
||||
# print("Value of energy of hessian: ", e_pos_hess)
|
||||
x_mod = x_mod.squeeze()
|
||||
|
||||
n = 6
|
||||
x_mod_list = [x_init, x_mod]
|
||||
|
||||
for i in range(n):
|
||||
x_mod, evals = sess.run([x_refine, es], {X: x_mod})
|
||||
x_mod = x_mod.squeeze()
|
||||
x_mod_list.append(x_mod)
|
||||
print("Value of energies after evaluation: ", evals)
|
||||
|
||||
x_mod_list = x_mod_list[:]
|
||||
|
||||
|
||||
series_xmod = np.stack(x_mod_list, axis=1)
|
||||
series_header = np.tile(data[0:1, None, :, :], (1, len(x_mod_list), 1, 1))
|
||||
|
||||
series_total = np.concatenate([series_header, series_xmod], axis=0)
|
||||
|
||||
series_total_full = np.ones((*series_total.shape[:-2], 66, 66))
|
||||
|
||||
series_total_full[:, :, 1:-1, 1:-1] = series_total
|
||||
|
||||
series_total = series_total_full
|
||||
|
||||
series_total = series_total.transpose((0, 2, 1, 3)).reshape((-1, len(x_mod_list)*66))
|
||||
im_total = rescale_im(series_total)
|
||||
imsave("latent_comb.png", im_total)
|
||||
|
||||
|
||||
def construct_label(weights, X, Y, Y_GT, model, target_vars):
|
||||
# for i in range(FLAGS.num_steps):
|
||||
# Y = Y + tf.random_normal(tf.shape(Y), mean=0.0, stddev=0.03)
|
||||
# e = model.forward(X, weights, label=Y)
|
||||
|
||||
# Y_grad = tf.clip_by_value(tf.gradients(e, [Y])[0], -1, 1)
|
||||
# Y = Y - 0.1 * Y_grad
|
||||
# Y = tf.clip_by_value(Y, 0, 1)
|
||||
|
||||
# Y = Y / tf.reduce_sum(Y, axis=[1], keepdims=True)
|
||||
|
||||
e_bias = tf.get_variable('e_bias', shape=10, initializer=tf.initializers.zeros())
|
||||
l1_norm = tf.placeholder(shape=(), dtype=tf.float32)
|
||||
l2_norm = tf.placeholder(shape=(), dtype=tf.float32)
|
||||
|
||||
def compute_logit(X, stop_grad=False, num_steps=0):
|
||||
batch_size = tf.shape(X)[0]
|
||||
X = tf.reshape(X, (batch_size, 1, 32, 32, 3))
|
||||
X = tf.reshape(tf.tile(X, (1, 10, 1, 1, 1)), (batch_size * 10, 32, 32, 3))
|
||||
Y_new = tf.reshape(Y, (batch_size*10, 10))
|
||||
|
||||
X_min = X - 8 / 255.
|
||||
X_max = X + 8 / 255.
|
||||
|
||||
for i in range(num_steps):
|
||||
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
|
||||
|
||||
energy_noise = model.forward(X, weights, label=Y, reuse=True)
|
||||
x_grad = tf.gradients(energy_noise, [X])[0]
|
||||
|
||||
|
||||
if FLAGS.proj_norm != 0.0:
|
||||
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
|
||||
|
||||
X = X - FLAGS.step_lr * x_grad
|
||||
X = tf.maximum(tf.minimum(X, X_max), X_min)
|
||||
|
||||
energy = model.forward(X, weights, label=Y_new)
|
||||
energy = -tf.reshape(energy, (batch_size, 10))
|
||||
|
||||
if stop_grad:
|
||||
energy = tf.stop_gradient(energy)
|
||||
|
||||
return energy
|
||||
|
||||
|
||||
# eps_norm = 30
|
||||
X_min = X - l1_norm / 255.
|
||||
X_max = X + l1_norm / 255.
|
||||
|
||||
for i in range(FLAGS.pgd):
|
||||
print("Constructed loop {} of pgd attack".format(i))
|
||||
X_init = X
|
||||
if i == 0:
|
||||
X = X + tf.to_float(tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)) / 255.
|
||||
|
||||
logit = compute_logit(X)
|
||||
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logit)
|
||||
|
||||
x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255.
|
||||
X = X + 2 * x_grad
|
||||
|
||||
if FLAGS.lnorm == -1:
|
||||
X = tf.maximum(tf.minimum(X, X_max), X_min)
|
||||
elif FLAGS.lnorm == 2:
|
||||
X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255., axes=[1, 2, 3])
|
||||
|
||||
energy_stopped = compute_logit(X, stop_grad=True, num_steps=FLAGS.num_steps) + e_bias
|
||||
|
||||
# # Y = tf.Print(Y, [Y])
|
||||
labels = tf.argmax(Y_GT, axis=1)
|
||||
# max_z = tf.argmax(energy_stopped, axis=1)
|
||||
|
||||
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=energy_stopped)
|
||||
optimizer = tf.train.AdamOptimizer(1e-2)
|
||||
train_op = optimizer.minimize(loss)
|
||||
|
||||
accuracy = tf.contrib.metrics.accuracy(tf.argmax(energy_stopped, axis=1), labels)
|
||||
target_vars['accuracy'] = accuracy
|
||||
target_vars['train_op'] = train_op
|
||||
target_vars['l1_norm'] = l1_norm
|
||||
target_vars['l2_norm'] = l2_norm
|
||||
|
||||
|
||||
def construct_energy(weights, X, Y, Y_GT, model, target_vars):
|
||||
energy = model.forward(X, weights, label=Y_GT)
|
||||
|
||||
for i in range(FLAGS.num_steps):
|
||||
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
|
||||
|
||||
energy_noise = model.forward(X, weights, label=Y_GT, reuse=True)
|
||||
x_grad = tf.gradients(energy_noise, [X])[0]
|
||||
|
||||
if FLAGS.proj_norm != 0.0:
|
||||
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
|
||||
|
||||
X = X - FLAGS.step_lr * x_grad
|
||||
X = tf.clip_by_value(X, 0, 1)
|
||||
|
||||
|
||||
target_vars['energy'] = energy
|
||||
target_vars['energy_end'] = energy_noise
|
||||
|
||||
|
||||
def construct_steps(weights, X, Y_GT, model, target_vars):
|
||||
n = 50
|
||||
scale_fac = 1.0
|
||||
|
||||
# if FLAGS.task == 'cycleclass':
|
||||
# scale_fac = 10.0
|
||||
|
||||
X_mods = []
|
||||
X = tf.identity(X)
|
||||
|
||||
mask = np.zeros((1, 32, 32, 3))
|
||||
|
||||
if FLAGS.task == "boxcorrupt":
|
||||
mask[:, 16:, :, :] = 1
|
||||
else:
|
||||
mask[:, :, :, :] = 1
|
||||
|
||||
mask = tf.Variable(tf.convert_to_tensor(mask, dtype=tf.float32), trainable=False)
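# for the boxcorrupt task only the lower half of the image is noised and
# updated; every other task updates all pixels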
|
||||
|
||||
# X_targ = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
|
||||
|
||||
for i in range(FLAGS.num_steps):
|
||||
X_old = X
|
||||
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005*scale_fac) * mask
|
||||
|
||||
energy_noise = model.forward(X, weights, label=Y_GT, reuse=True)
|
||||
x_grad = tf.gradients(energy_noise, [X])[0]
|
||||
|
||||
if FLAGS.proj_norm != 0.0:
|
||||
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
|
||||
|
||||
X = X - FLAGS.step_lr * x_grad * scale_fac * mask
|
||||
X = tf.clip_by_value(X, 0, 1)
|
||||
|
||||
if i % n == (n-1):
|
||||
X_mods.append(X)
|
||||
|
||||
print("Constructing step {}".format(i))
|
||||
|
||||
target_vars['X_final'] = X
|
||||
target_vars['X_mods'] = X_mods
|
||||
|
||||
|
||||
def nearest_neighbor(dataset, sess, target_vars, logdir):
|
||||
X = target_vars['X']
|
||||
Y_GT = target_vars['Y_GT']
|
||||
x_final = target_vars['X_final']
|
||||
|
||||
noise = np.random.uniform(0, 1, size=[10, 32, 32, 3])
|
||||
# label = np.random.randint(0, 10, size=[10])
|
||||
label = np.eye(10)
|
||||
|
||||
coarse = noise
|
||||
|
||||
for i in range(10):
|
||||
x_new = sess.run([x_final], {X:coarse, Y_GT:label})[0]
|
||||
coarse = x_new
|
||||
|
||||
x_new_dense = x_new.reshape(10, 1, 32*32*3)
|
||||
dataset_dense = dataset.reshape(1, 50000, 32*32*3)
|
||||
|
||||
diff = np.square(x_new_dense - dataset_dense).sum(axis=2)
|
||||
diff_idx = np.argsort(diff, axis=1)
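# nearest neighbours by pixel-space L2 distance between generated samples and
# the training set, a quick check that the model is not just memorizing images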
|
||||
|
||||
panel = np.zeros((32*10, 32*6, 3))
|
||||
|
||||
dataset_rescale = rescale_im(dataset)
|
||||
x_new_rescale = rescale_im(x_new)
|
||||
|
||||
for i in range(10):
|
||||
panel[i*32:i*32+32, :32] = x_new_rescale[i]
|
||||
for j in range(5):
|
||||
panel[i*32:i*32+32, 32*j+32:32*j+64] = dataset_rescale[diff_idx[i, j]]
|
||||
|
||||
imsave(osp.join(logdir, "nearest.png"), panel)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
if FLAGS.dataset == "cifar10":
|
||||
dataset = Cifar10(train=True, noise=False)
|
||||
test_dataset = Cifar10(train=False, noise=False)
|
||||
else:
|
||||
dataset = Imagenet(train=True)
|
||||
test_dataset = Imagenet(train=False)
|
||||
|
||||
if FLAGS.svhn:
|
||||
dataset = Svhn(train=True)
|
||||
test_dataset = Svhn(train=False)
|
||||
|
||||
if FLAGS.task == 'latent':
|
||||
dataset = DSprites()
|
||||
test_dataset = dataset
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True)
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True)
|
||||
|
||||
hidden_dim = 128
|
||||
|
||||
if FLAGS.large_model:
|
||||
model = ResNet32Large(num_filters=hidden_dim)
|
||||
elif FLAGS.larger_model:
|
||||
model = ResNet32Larger(num_filters=hidden_dim)
|
||||
elif FLAGS.wider_model:
|
||||
if FLAGS.dataset == 'imagenet':
|
||||
model = ResNet32Wider(num_filters=196, train=False)
|
||||
else:
|
||||
model = ResNet32Wider(num_filters=256, train=False)
|
||||
else:
|
||||
model = ResNet32(num_filters=hidden_dim)
|
||||
|
||||
if FLAGS.task == 'latent':
|
||||
model = DspritesNet()
|
||||
|
||||
weights = model.construct_weights('context_{}'.format(0))
|
||||
|
||||
total_parameters = 0
|
||||
for variable in tf.trainable_variables():
|
||||
# shape is an array of tf.Dimension
|
||||
shape = variable.get_shape()
|
||||
variable_parameters = 1
|
||||
for dim in shape:
|
||||
variable_parameters *= dim.value
|
||||
total_parameters += variable_parameters
|
||||
print("Model has a total of {} parameters".format(total_parameters))
|
||||
|
||||
config = tf.ConfigProto()
|
||||
sess = tf.InteractiveSession()
|
||||
|
||||
if FLAGS.task == 'latent':
|
||||
X = tf.placeholder(shape=(None, 64, 64), dtype = tf.float32)
|
||||
else:
|
||||
X = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
|
||||
|
||||
if FLAGS.dataset == "cifar10":
|
||||
Y = tf.placeholder(shape=(None, 10), dtype = tf.float32)
|
||||
Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32)
|
||||
elif FLAGS.dataset == "imagenet":
|
||||
Y = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
|
||||
Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
|
||||
|
||||
target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT}
|
||||
|
||||
if FLAGS.task == 'label':
|
||||
construct_label(weights, X, Y, Y_GT, model, target_vars)
|
||||
elif FLAGS.task == 'labelfinetune':
|
||||
construct_finetune_label(weights, X, Y, Y_GT, model, target_vars, )
|
||||
elif FLAGS.task == 'energyeval' or FLAGS.task == 'mixenergy':
|
||||
construct_energy(weights, X, Y, Y_GT, model, target_vars)
|
||||
elif FLAGS.task in ('anticorrupt', 'boxcorrupt', 'crossclass', 'cycleclass', 'democlass', 'nearestneighbor'):
|
||||
construct_steps(weights, X, Y_GT, model, target_vars)
|
||||
elif FLAGS.task == 'latent':
|
||||
construct_latent(weights, X, Y_GT, model, target_vars)
|
||||
|
||||
sess.run(tf.global_variables_initializer())
|
||||
saver = loader = tf.train.Saver(max_to_keep=10)
|
||||
savedir = osp.join('cachedir', FLAGS.exp)
|
||||
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
|
||||
if not osp.exists(logdir):
|
||||
os.makedirs(logdir)
|
||||
|
||||
initialize()
|
||||
if FLAGS.resume_iter != -1:
|
||||
model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter))
|
||||
resume_itr = FLAGS.resume_iter
|
||||
|
||||
if FLAGS.task in ('label', 'boxcorrupt', 'labelfinetune', 'energyeval', 'crossclass', 'mixenergy'):
|
||||
optimistic_restore(sess, model_file)
|
||||
# saver.restore(sess, model_file)
|
||||
else:
|
||||
# optimistic_restore(sess, model_file)
|
||||
saver.restore(sess, model_file)
|
||||
|
||||
if FLAGS.task == 'label':
|
||||
if FLAGS.labelgrid:
|
||||
vals = []
|
||||
if FLAGS.lnorm == -1:
|
||||
for i in range(31):
|
||||
accuracies = label(dataloader, test_dataloader, target_vars, sess, l1val=i)
|
||||
vals.append(accuracies)
|
||||
elif FLAGS.lnorm == 2:
|
||||
for i in range(0, 100, 5):
|
||||
accuracies = label(dataloader, test_dataloader, target_vars, sess, l2val=i)
|
||||
vals.append(accuracies)
|
||||
|
||||
np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals)
|
||||
else:
|
||||
label(dataloader, test_dataloader, target_vars, sess)
|
||||
elif FLAGS.task == 'labelfinetune':
|
||||
labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=FLAGS.lival, l2val=FLAGS.l2val)
|
||||
elif FLAGS.task == 'energyeval':
|
||||
energyeval(dataloader, test_dataloader, target_vars, sess)
|
||||
elif FLAGS.task == 'mixenergy':
|
||||
energyevalmix(dataloader, test_dataloader, target_vars, sess)
|
||||
elif FLAGS.task == 'anticorrupt':
|
||||
anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
|
||||
elif FLAGS.task == 'boxcorrupt':
|
||||
# boxcorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
|
||||
boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess)
|
||||
elif FLAGS.task == 'crossclass':
|
||||
crossclass(test_dataloader, weights, model, target_vars, logdir, sess)
|
||||
elif FLAGS.task == 'cycleclass':
|
||||
cycleclass(test_dataloader, weights, model, target_vars, logdir, sess)
|
||||
elif FLAGS.task == 'democlass':
|
||||
democlass(test_dataloader, weights, model, target_vars, logdir, sess)
|
||||
elif FLAGS.task == 'nearestneighbor':
|
||||
# print(dir(dataset))
|
||||
# print(type(dataset))
|
||||
nearest_neighbor(dataset.data.train_data / 255, sess, target_vars, logdir)
|
||||
elif FLAGS.task == 'latent':
|
||||
latent(test_dataloader, weights, model, target_vars, sess)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
292
EBMs/fid.py
Normal file
@ -0,0 +1,292 @@
#!/usr/bin/env python3
|
||||
''' Calculates the Frechet Inception Distance (FID) to evaluate GANs.
|
||||
|
||||
The FID metric calculates the distance between two distributions of images.
|
||||
Typically, we have summary statistics (mean & covariance matrix) of one
|
||||
of these distributions, while the 2nd distribution is given by a GAN.
|
||||
|
||||
When run as a stand-alone program, it compares the distribution of
|
||||
images that are stored as PNG/JPEG at a specified location with a
|
||||
distribution given by summary statistics (in pickle format).
|
||||
|
||||
The FID is calculated by assuming that X_1 and X_2 are the activations of
|
||||
the pool_3 layer of the inception net for generated samples and real world
|
||||
samples respectively.
|
||||
|
||||
See --help to see further details.
|
||||
'''
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import numpy as np
|
||||
import os
import sys
|
||||
import gzip, pickle
|
||||
import tensorflow as tf
|
||||
from scipy.misc import imread
|
||||
from scipy import linalg
|
||||
import pathlib
|
||||
import urllib.request
|
||||
import tarfile
|
||||
import warnings
|
||||
|
||||
MODEL_DIR = '/tmp/imagenet'
|
||||
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
||||
pool3 = None
|
||||
|
||||
class InvalidFIDException(Exception):
|
||||
pass
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
def get_fid_score(images, images_gt):
|
||||
images = np.stack(images, 0)
|
||||
images_gt = np.stack(images_gt, 0)
|
||||
|
||||
with tf.Session() as sess:
|
||||
m1, s1 = calculate_activation_statistics(images, sess)
|
||||
m2, s2 = calculate_activation_statistics(images_gt, sess)
|
||||
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
||||
|
||||
print("Obtained fid value of {}".format(fid_value))
|
||||
return fid_value
|
||||
|
||||
|
||||
def create_inception_graph(pth):
|
||||
"""Creates a graph from saved GraphDef file."""
|
||||
# Creates graph from saved graph_def.pb.
|
||||
with tf.gfile.FastGFile( pth, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString( f.read())
|
||||
_ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# code for handling inception net derived from
|
||||
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
|
||||
def _get_inception_layer(sess):
|
||||
"""Prepares inception net for batched usage and returns pool_3 layer. """
|
||||
layername = 'FID_Inception_Net/pool_3:0'
|
||||
pool3 = sess.graph.get_tensor_by_name(layername)
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
if shape._dims != []:
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
|
||||
return pool3
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_activations(images, sess, batch_size=50, verbose=False):
|
||||
"""Calculates the activations of the pool_3 layer for all images.
|
||||
|
||||
Params:
|
||||
-- images : Numpy array of dimension (n_images, hi, wi, 3). The values
|
||||
must lie between 0 and 255.
|
||||
-- sess : current session
|
||||
-- batch_size : the images numpy array is split into batches with batch size
|
||||
batch_size. A reasonable batch size depends on the available hardware.
|
||||
-- verbose : If set to True and parameter out_step is given, the number of calculated
|
||||
batches is reported.
|
||||
Returns:
|
||||
-- A numpy array of dimension (num images, 2048) that contains the
|
||||
activations of the given tensor when feeding inception with the query tensor.
|
||||
"""
|
||||
# inception_layer = _get_inception_layer(sess)
|
||||
d0 = images.shape[0]
|
||||
if batch_size > d0:
|
||||
print("warning: batch size is bigger than the data size. setting batch size to data size")
|
||||
batch_size = d0
|
||||
n_batches = d0//batch_size
|
||||
n_used_imgs = n_batches*batch_size
|
||||
pred_arr = np.empty((n_used_imgs,2048))
|
||||
for i in range(n_batches):
|
||||
if verbose:
|
||||
print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
|
||||
start = i*batch_size
|
||||
end = start + batch_size
|
||||
batch = images[start:end]
|
||||
pred = sess.run(pool3, {'ExpandDims:0': batch})
|
||||
pred_arr[start:end] = pred.reshape(batch_size,-1)
|
||||
if verbose:
|
||||
print(" done")
|
||||
return pred_arr
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
||||
"""Numpy implementation of the Frechet Distance.
|
||||
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
||||
and X_2 ~ N(mu_2, C_2) is
|
||||
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
||||
Stable version by Dougal J. Sutherland.
|
||||
|
||||
Params:
|
||||
-- mu1 : Numpy array containing the activations of the pool_3 layer of the
|
||||
inception net ( like returned by the function 'get_predictions')
|
||||
for generated samples.
|
||||
-- mu2 : The sample mean over activations of the pool_3 layer, precalculated
|
||||
on a representative data set.
|
||||
-- sigma1: The covariance matrix over activations of the pool_3 layer for
|
||||
generated samples.
|
||||
-- sigma2: The covariance matrix over activations of the pool_3 layer,
|
||||
precalculated on a representative data set.
|
||||
|
||||
Returns:
|
||||
-- : The Frechet Distance.
|
||||
"""
|
||||
|
||||
mu1 = np.atleast_1d(mu1)
|
||||
mu2 = np.atleast_1d(mu2)
|
||||
|
||||
sigma1 = np.atleast_2d(sigma1)
|
||||
sigma2 = np.atleast_2d(sigma2)
|
||||
|
||||
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
|
||||
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
|
||||
|
||||
diff = mu1 - mu2
|
||||
|
||||
# product might be almost singular
|
||||
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
||||
if not np.isfinite(covmean).all():
|
||||
msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
|
||||
warnings.warn(msg)
|
||||
offset = np.eye(sigma1.shape[0]) * eps
|
||||
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
||||
|
||||
# numerical error might give slight imaginary component
|
||||
if np.iscomplexobj(covmean):
|
||||
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
||||
m = np.max(np.abs(covmean.imag))
|
||||
raise ValueError("Imaginary component {}".format(m))
|
||||
covmean = covmean.real
|
||||
|
||||
tr_covmean = np.trace(covmean)
|
||||
|
||||
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
||||
#-------------------------------------------------------------------------------
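# sanity-check sketch (not in the original file): with identical covariances the
# frechet distance above reduces to ||mu_1 - mu_2||^2, since
# Tr(C + C - 2*sqrt(C*C)) = 0; the helper name below is hypothetical.
def _frechet_distance_sanity_check(dim=8, seed=0):
    rng = np.random.RandomState(seed)
    mu1, mu2 = rng.rand(dim), rng.rand(dim)
    sigma = np.eye(dim)
    fid = calculate_frechet_distance(mu1, sigma, mu2, sigma)
    assert np.allclose(fid, np.sum((mu1 - mu2) ** 2), atol=1e-5)
    return fid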
|
||||
|
||||
|
||||
def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
|
||||
"""Calculation of the statistics used by the FID.
|
||||
Params:
|
||||
-- images : Numpy array of dimension (n_images, hi, wi, 3). The values
|
||||
must lie between 0 and 255.
|
||||
-- sess : current session
|
||||
-- batch_size : the images numpy array is split into batches with batch size
|
||||
batch_size. A reasonable batch size depends on the available hardware.
|
||||
-- verbose : If set to True and parameter out_step is given, the number of calculated
|
||||
batches is reported.
|
||||
Returns:
|
||||
-- mu : The mean over samples of the activations of the pool_3 layer of
|
||||
the inception model.
|
||||
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
||||
the inception model.
|
||||
"""
|
||||
act = get_activations(images, sess, batch_size, verbose)
|
||||
mu = np.mean(act, axis=0)
|
||||
sigma = np.cov(act, rowvar=False)
|
||||
return mu, sigma
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# The following functions aren't needed for calculating the FID
|
||||
# they're just here to make this module work as a stand-alone script
|
||||
# for calculating FID scores
|
||||
#-------------------------------------------------------------------------------
|
||||
def check_or_download_inception(inception_path):
|
||||
''' Checks if the path to the inception file is valid, or downloads
|
||||
the file if it is not present. '''
|
||||
INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
||||
if inception_path is None:
|
||||
inception_path = '/tmp'
|
||||
inception_path = pathlib.Path(inception_path)
|
||||
model_file = inception_path / 'classify_image_graph_def.pb'
|
||||
if not model_file.exists():
|
||||
print("Downloading Inception model")
|
||||
from urllib import request
|
||||
import tarfile
|
||||
fn, _ = request.urlretrieve(INCEPTION_URL)
|
||||
with tarfile.open(fn, mode='r') as f:
|
||||
f.extract('classify_image_graph_def.pb', str(model_file.parent))
|
||||
return str(model_file)
|
||||
|
||||
|
||||
def _handle_path(path, sess):
|
||||
if path.endswith('.npz'):
|
||||
f = np.load(path)
|
||||
m, s = f['mu'][:], f['sigma'][:]
|
||||
f.close()
|
||||
else:
|
||||
path = pathlib.Path(path)
|
||||
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
|
||||
x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
|
||||
m, s = calculate_activation_statistics(x, sess)
|
||||
return m, s
|
||||
|
||||
|
||||
def calculate_fid_given_paths(paths, inception_path):
|
||||
''' Calculates the FID of two paths. '''
|
||||
inception_path = check_or_download_inception(inception_path)
|
||||
|
||||
for p in paths:
|
||||
if not os.path.exists(p):
|
||||
raise RuntimeError("Invalid path: %s" % p)
|
||||
|
||||
create_inception_graph(str(inception_path))
|
||||
with tf.Session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
m1, s1 = _handle_path(paths[0], sess)
|
||||
m2, s2 = _handle_path(paths[1], sess)
|
||||
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
||||
return fid_value
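# usage sketch (illustrative): comparing a folder of generated images against a
# folder of reference images; both directory names are hypothetical placeholders
# and each folder must contain .png or .jpg files.
def _fid_from_folders_example():
    return calculate_fid_given_paths(['./samples', './reference'], inception_path=None)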
|
||||
|
||||
|
||||
def _init_inception():
|
||||
global pool3
|
||||
if not os.path.exists(MODEL_DIR):
|
||||
os.makedirs(MODEL_DIR)
|
||||
filename = DATA_URL.split('/')[-1]
|
||||
filepath = os.path.join(MODEL_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
|
||||
filename, float(count * block_size) / float(total_size) * 100.0))
|
||||
sys.stdout.flush()
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
|
||||
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
|
||||
with tf.gfile.FastGFile(os.path.join(
|
||||
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name='')
|
||||
# Works with an arbitrary minibatch size.
|
||||
with tf.Session() as sess:
|
||||
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
if shape._dims != []:
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
|
||||
|
||||
|
||||
if pool3 is None:
|
||||
_init_inception()
|
129
EBMs/hmc.py
Normal file
@ -0,0 +1,129 @@
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.platform import flags
|
||||
flags.DEFINE_bool('proposal_debug', False, 'Print hmc acceptance rates')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
def kinetic_energy(velocity):
|
||||
"""Kinetic energy of the current velocity (assuming a standard Gaussian)
|
||||
(x dot x) / 2
|
||||
|
||||
Parameters
|
||||
----------
|
||||
velocity : tf.Variable
|
||||
Vector of current velocity
|
||||
|
||||
Returns
|
||||
-------
|
||||
kinetic_energy : float
|
||||
"""
|
||||
return 0.5 * tf.square(velocity)
|
||||
|
||||
def hamiltonian(position, velocity, energy_function):
|
||||
"""Computes the Hamiltonian of the current position, velocity pair
|
||||
|
||||
H = U(x) + K(v)
|
||||
|
||||
U is the potential energy and is = -log_posterior(x)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
position : tf.Variable
|
||||
Position or state vector x (sample from the target distribution)
|
||||
velocity : tf.Variable
|
||||
Auxiliary velocity variable
|
||||
energy_function
|
||||
Function from position (state) to 'energy'
|
||||
= -log_posterior
|
||||
|
||||
Returns
|
||||
-------
|
||||
hamiltonian : float
|
||||
"""
|
||||
batch_size = tf.shape(velocity)[0]
|
||||
kinetic_energy_flat = tf.reshape(kinetic_energy(velocity), (batch_size, -1))
|
||||
return tf.squeeze(energy_function(position)) + tf.reduce_sum(kinetic_energy_flat, axis=[1])
|
||||
|
||||
def leapfrog_step(x0,
|
||||
v0,
|
||||
neg_log_posterior,
|
||||
step_size,
|
||||
num_steps):
|
||||
|
||||
# Start by updating the velocity a half-step
|
||||
v = v0 - 0.5 * step_size * tf.gradients(neg_log_posterior(x0), x0)[0]
|
||||
|
||||
# Initialize x to be the first step
|
||||
x = x0 + step_size * v
|
||||
|
||||
for i in range(num_steps):
|
||||
# Compute gradient of the log-posterior with respect to x
|
||||
gradient = tf.gradients(neg_log_posterior(x), x)[0]
|
||||
|
||||
# Update velocity
|
||||
v = v - step_size * gradient
|
||||
|
||||
# x_clip = tf.clip_by_value(x, 0.0, 1.0)
|
||||
# x = x_clip
|
||||
# v_mask = 1 - 2 * tf.abs(tf.sign(x - x_clip))
|
||||
# v = v * v_mask
|
||||
|
||||
# Update x
|
||||
x = x + step_size * v
|
||||
|
||||
# x = tf.clip_by_value(x, -0.01, 1.01)
|
||||
|
||||
# x = tf.Print(x, [tf.reduce_min(x), tf.reduce_max(x), tf.reduce_mean(x)])
|
||||
|
||||
# Do a final update of the velocity for a half step
|
||||
v = v - 0.5 * step_size * tf.gradients(neg_log_posterior(x), x)[0]
|
||||
|
||||
# return new proposal state
|
||||
return x, v
|
||||
|
||||
def hmc(initial_x,
|
||||
step_size,
|
||||
num_steps,
|
||||
neg_log_posterior):
|
||||
"""Summary
|
||||
|
||||
Parameters
|
||||
----------
|
||||
initial_x : tf.Variable
|
||||
Initial sample x ~ p
|
||||
step_size : float
|
||||
Step-size in Hamiltonian simulation
|
||||
num_steps : int
|
||||
Number of steps to take in Hamiltonian simulation
|
||||
neg_log_posterior : callable
|
||||
Negative log posterior (unnormalized) for the target distribution
|
||||
|
||||
Returns
|
||||
-------
|
||||
sample :
|
||||
Sample ~ target distribution
|
||||
"""
|
||||
|
||||
v0 = tf.random_normal(tf.shape(initial_x))
|
||||
x, v = leapfrog_step(initial_x,
|
||||
v0,
|
||||
step_size=step_size,
|
||||
num_steps=num_steps,
|
||||
neg_log_posterior=neg_log_posterior)
|
||||
|
||||
orig = hamiltonian(initial_x, v0, neg_log_posterior)
|
||||
current = hamiltonian(x, v, neg_log_posterior)
|
||||
|
||||
prob_accept = tf.exp(orig - current)
|
||||
|
||||
if FLAGS.proposal_debug:
|
||||
prob_accept = tf.Print(prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))])
|
||||
|
||||
uniform = tf.random_uniform(tf.shape(prob_accept))
|
||||
keep_mask = (prob_accept > uniform)
|
||||
# print(keep_mask.get_shape())
|
||||
|
||||
x_new = tf.where(keep_mask, x, initial_x)
|
||||
return x_new
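# usage sketch (illustrative, not from the original file): drawing hmc proposals
# for a standard gaussian energy -log p(x) = 0.5 * sum(x^2), using the same
# tf1-style graph/session api as the rest of this module.
def _hmc_gaussian_example():
    def neg_log_posterior(x):
        return 0.5 * tf.reduce_sum(tf.square(x), axis=1)

    x0 = tf.random_normal((16, 2))
    x1 = hmc(x0, step_size=0.1, num_steps=10, neg_log_posterior=neg_log_posterior)

    with tf.Session() as sess:
        return sess.run(x1)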
|
73
EBMs/imagenet_demo.py
Normal file
@ -0,0 +1,73 @@
from models import ResNet128
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
from tensorflow.python.platform import flags
|
||||
import tensorflow as tf
|
||||
import imageio
|
||||
from utils import optimistic_restore
|
||||
|
||||
|
||||
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_integer('num_steps', 200, 'num of steps for conditional imagenet sampling')
|
||||
flags.DEFINE_float('step_lr', 180., 'step size for Langevin dynamics')
|
||||
flags.DEFINE_integer('batch_size', 16, 'number of steps to run')
|
||||
flags.DEFINE_string('exp', 'default', 'name of experiments')
|
||||
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
|
||||
flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization in weights in a model')
|
||||
flags.DEFINE_bool('cclass', True, 'conditional models')
|
||||
flags.DEFINE_bool('use_attention', False, 'using attention')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
def rescale_im(im):
|
||||
return np.clip(im * 256, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = ResNet128(num_filters=64)
|
||||
X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
|
||||
LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
|
||||
|
||||
sess = tf.InteractiveSession()
|
||||
weights = model.construct_weights("context_0")
|
||||
|
||||
x_mod = X_NOISE
|
||||
x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
|
||||
mean=0.0,
|
||||
stddev=0.005)
|
||||
|
||||
energy_noise = energy_start = model.forward(x_mod, weights, label=LABEL,
|
||||
reuse=True, stop_at_grad=False, stop_batch=True)
|
||||
|
||||
x_grad = tf.gradients(energy_noise, [x_mod])[0]
|
||||
energy_noise_old = energy_noise
|
||||
|
||||
lr = FLAGS.step_lr
|
||||
|
||||
x_last = x_mod - (lr) * x_grad
|
||||
|
||||
x_mod = x_last
|
||||
x_mod = tf.clip_by_value(x_mod, 0, 1)
|
||||
x_output = x_mod
|
||||
|
||||
sess.run(tf.global_variables_initializer())
|
||||
saver = loader = tf.train.Saver()
|
||||
|
||||
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
|
||||
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
|
||||
saver.restore(sess, model_file)
|
||||
|
||||
lx = np.random.permutation(1000)[:16]
|
||||
ims = []
|
||||
|
||||
# What to initialize sampling with.
|
||||
x_mod = np.random.uniform(0, 1, size=(FLAGS.batch_size, 128, 128, 3))
|
||||
labels = np.eye(1000)[lx]
|
||||
|
||||
for i in range(FLAGS.num_steps):
|
||||
e, x_mod = sess.run([energy_noise, x_output], {X_NOISE:x_mod, LABEL:labels})
|
||||
ims.append(rescale_im(x_mod).reshape((4, 4, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((512, 512, 3)))
|
||||
|
||||
imageio.mimwrite('sample.gif', ims)
|
||||
|
||||
|
337
EBMs/imagenet_preprocessing.py
Normal file
@ -0,0 +1,337 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Image pre-processing utilities.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
IMAGE_DEPTH = 3 # color images
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
# _R_MEAN = 123.68
|
||||
# _G_MEAN = 116.78
|
||||
# _B_MEAN = 103.94
|
||||
# _CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN]
|
||||
_CHANNEL_MEANS = [0.0, 0.0, 0.0]
|
||||
|
||||
# The lower bound for the smallest side of the image for aspect-preserving
|
||||
# resizing. For example, if an image is 500 x 1000, it will be resized to
|
||||
# _RESIZE_MIN x (_RESIZE_MIN * 2).
|
||||
_RESIZE_MIN = 128
|
||||
|
||||
|
||||
def _decode_crop_and_flip(image_buffer, bbox, num_channels):
|
||||
"""Crops the given image to a random part of the image, and randomly flips.
|
||||
|
||||
We use the fused decode_and_crop op, which performs better than the two ops
|
||||
used separately in series, but note that this requires that the image be
|
||||
passed in as an un-decoded string Tensor.
|
||||
|
||||
Args:
|
||||
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
num_channels: Integer depth of the image buffer for decoding.
|
||||
|
||||
Returns:
|
||||
3-D tensor with cropped image.
|
||||
|
||||
"""
|
||||
# A large fraction of image datasets contain a human-annotated bounding box
|
||||
# delineating the region of the image containing the object of interest. We
|
||||
# choose to create a new bounding box for the object which is a randomly
|
||||
# distorted version of the human-annotated bounding box that obeys an
|
||||
# allowed range of aspect ratios, sizes and overlap with the human-annotated
|
||||
# bounding box. If no box is supplied, then we assume the bounding box is
|
||||
# the entire image.
|
||||
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
|
||||
tf.image.extract_jpeg_shape(image_buffer),
|
||||
bounding_boxes=bbox,
|
||||
min_object_covered=0.1,
|
||||
aspect_ratio_range=[0.75, 1.33],
|
||||
area_range=[0.05, 1.0],
|
||||
max_attempts=100,
|
||||
use_image_if_no_bounding_boxes=True)
|
||||
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
|
||||
|
||||
# Reassemble the bounding box in the format the crop op requires.
|
||||
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
||||
target_height, target_width, _ = tf.unstack(bbox_size)
|
||||
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
|
||||
|
||||
# Use the fused decode and crop op here, which is faster than each in series.
|
||||
cropped = tf.image.decode_and_crop_jpeg(
|
||||
image_buffer, crop_window, channels=num_channels)
|
||||
|
||||
# Flip to add a little more random distortion in.
|
||||
cropped = tf.image.random_flip_left_right(cropped)
|
||||
return cropped
|
||||
|
||||
|
||||
def _central_crop(image, crop_height, crop_width):
|
||||
"""Performs central crops of the given image list.
|
||||
|
||||
Args:
|
||||
image: a 3-D image tensor
|
||||
crop_height: the height of the image following the crop.
|
||||
crop_width: the width of the image following the crop.
|
||||
|
||||
Returns:
|
||||
3-D tensor with cropped image.
|
||||
"""
|
||||
shape = tf.shape(input=image)
|
||||
height, width = shape[0], shape[1]
|
||||
|
||||
amount_to_be_cropped_h = (height - crop_height)
|
||||
crop_top = amount_to_be_cropped_h // 2
|
||||
amount_to_be_cropped_w = (width - crop_width)
|
||||
crop_left = amount_to_be_cropped_w // 2
|
||||
return tf.slice(
|
||||
image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
|
||||
|
||||
|
||||
def _mean_image_subtraction(image, means, num_channels):
|
||||
"""Subtracts the given means from each image channel.
|
||||
|
||||
For example:
|
||||
means = [123.68, 116.779, 103.939]
|
||||
image = _mean_image_subtraction(image, means)
|
||||
|
||||
Note that the rank of `image` must be known.
|
||||
|
||||
Args:
|
||||
image: a tensor of size [height, width, C].
|
||||
means: a C-vector of values to subtract from each channel.
|
||||
num_channels: number of color channels in the image that will be distorted.
|
||||
|
||||
Returns:
|
||||
the centered image.
|
||||
|
||||
Raises:
|
||||
ValueError: If the rank of `image` is unknown, if `image` has a rank other
|
||||
than three or if the number of channels in `image` doesn't match the
|
||||
number of values in `means`.
|
||||
"""
|
||||
if image.get_shape().ndims != 3:
|
||||
raise ValueError('Input must be of size [height, width, C>0]')
|
||||
|
||||
if len(means) != num_channels:
|
||||
raise ValueError('len(means) must match the number of channels')
|
||||
|
||||
# We have a 1-D tensor of means; convert to 3-D.
|
||||
means = tf.expand_dims(tf.expand_dims(means, 0), 0)
|
||||
|
||||
return image - means
|
||||
|
||||
|
||||
def _smallest_size_at_least(height, width, resize_min):
|
||||
"""Computes new shape with the smallest side equal to `smallest_side`.
|
||||
|
||||
Computes new shape with the smallest side equal to `smallest_side` while
|
||||
preserving the original aspect ratio.
|
||||
|
||||
Args:
|
||||
height: an int32 scalar tensor indicating the current height.
|
||||
width: an int32 scalar tensor indicating the current width.
|
||||
resize_min: A python integer or scalar `Tensor` indicating the size of
|
||||
the smallest side after resize.
|
||||
|
||||
Returns:
|
||||
new_height: an int32 scalar tensor indicating the new height.
|
||||
new_width: an int32 scalar tensor indicating the new width.
|
||||
"""
|
||||
resize_min = tf.cast(resize_min, tf.float32)
|
||||
|
||||
# Convert to floats to make subsequent calculations go smoothly.
|
||||
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
|
||||
|
||||
smaller_dim = tf.minimum(height, width)
|
||||
scale_ratio = resize_min / smaller_dim
|
||||
|
||||
# Convert back to ints to make heights and widths that TF ops will accept.
|
||||
new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
|
||||
new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
|
||||
|
||||
return new_height, new_width
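# worked example (illustrative): for a 500 x 1000 image with _RESIZE_MIN = 128 the
# smaller side is 500, the scale ratio is 128 / 500 = 0.256, and the output shape is
# ceil(500 * 0.256) x ceil(1000 * 0.256) = 128 x 256, matching the comment next to
# _RESIZE_MIN above. the pure-python helper below is hypothetical and only mirrors
# the arithmetic of _smallest_size_at_least.
def _smallest_size_example(height=500.0, width=1000.0, resize_min=128.0):
    import math
    scale_ratio = resize_min / min(height, width)
    return int(math.ceil(height * scale_ratio)), int(math.ceil(width * scale_ratio))  # (128, 256)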
|
||||
|
||||
|
||||
def _aspect_preserving_resize(image, resize_min):
|
||||
"""Resize images preserving the original aspect ratio.
|
||||
|
||||
Args:
|
||||
image: A 3-D image `Tensor`.
|
||||
resize_min: A python integer or scalar `Tensor` indicating the size of
|
||||
the smallest side after resize.
|
||||
|
||||
Returns:
|
||||
resized_image: A 3-D tensor containing the resized image.
|
||||
"""
|
||||
shape = tf.shape(input=image)
|
||||
height, width = shape[0], shape[1]
|
||||
|
||||
new_height, new_width = _smallest_size_at_least(height, width, resize_min)
|
||||
|
||||
return _resize_image(image, new_height, new_width)
|
||||
|
||||
|
||||
def _resize_image(image, height, width):
|
||||
"""Simple wrapper around tf.resize_images.
|
||||
|
||||
This is primarily to make sure we use the same `ResizeMethod` and other
|
||||
details each time.
|
||||
|
||||
Args:
|
||||
image: A 3-D image `Tensor`.
|
||||
height: The target height for the resized image.
|
||||
width: The target width for the resized image.
|
||||
|
||||
Returns:
|
||||
resized_image: A 3-D tensor containing the resized image. The first two
|
||||
dimensions have the shape [height, width].
|
||||
"""
|
||||
return tf.image.resize_images(
|
||||
image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
|
||||
align_corners=False)
|
||||
|
||||
|
||||
def preprocess_image(image_buffer, bbox, output_height, output_width,
|
||||
num_channels, is_training=False):
|
||||
"""Preprocesses the given image.
|
||||
|
||||
Preprocessing includes decoding, cropping, and resizing for both training
|
||||
and eval images. Training preprocessing, however, introduces some random
|
||||
distortion of the image to improve accuracy.
|
||||
|
||||
Args:
|
||||
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
output_height: The height of the image after preprocessing.
|
||||
output_width: The width of the image after preprocessing.
|
||||
num_channels: Integer depth of the image buffer for decoding.
|
||||
is_training: `True` if we're preprocessing the image for training and
|
||||
`False` otherwise.
|
||||
|
||||
Returns:
|
||||
A preprocessed image.
|
||||
"""
|
||||
if is_training:
|
||||
# For training, we want to randomize some of the distortions.
|
||||
image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
|
||||
image = _resize_image(image, output_height, output_width)
|
||||
else:
|
||||
# For validation, we want to decode, resize, then just crop the middle.
|
||||
image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
|
||||
image = _aspect_preserving_resize(image, _RESIZE_MIN)
|
||||
print(image)
|
||||
image = _central_crop(image, output_height, output_width)
|
||||
|
||||
image.set_shape([output_height, output_width, num_channels])
|
||||
|
||||
return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
|
||||
|
||||
|
||||
def parse_example_proto(example_serialized):
|
||||
"""Parses an Example proto containing a training example of an image.
|
||||
|
||||
The output of the build_image_data.py image preprocessing script is a dataset
|
||||
containing serialized Example protocol buffers. Each Example proto contains
|
||||
the following fields:
|
||||
|
||||
image/height: 462
|
||||
image/width: 581
|
||||
image/colorspace: 'RGB'
|
||||
image/channels: 3
|
||||
image/class/label: 615
|
||||
image/class/synset: 'n03623198'
|
||||
image/class/text: 'knee pad'
|
||||
image/object/bbox/xmin: 0.1
|
||||
image/object/bbox/xmax: 0.9
|
||||
image/object/bbox/ymin: 0.2
|
||||
image/object/bbox/ymax: 0.6
|
||||
image/object/bbox/label: 615
|
||||
image/format: 'JPEG'
|
||||
image/filename: 'ILSVRC2012_val_00041207.JPEG'
|
||||
image/encoded: <JPEG encoded string>
|
||||
|
||||
Args:
|
||||
example_serialized: scalar Tensor tf.string containing a serialized
|
||||
Example protocol buffer.
|
||||
|
||||
Returns:
|
||||
image_buffer: Tensor tf.string containing the contents of a JPEG file.
|
||||
label: Tensor tf.int32 containing the label.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
text: Tensor tf.string containing the human-readable label.
|
||||
"""
|
||||
# Dense features in Example proto.
|
||||
feature_map = {
|
||||
'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
|
||||
default_value=''),
|
||||
'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
|
||||
default_value=-1),
|
||||
'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
|
||||
default_value=''),
|
||||
}
|
||||
sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
|
||||
# Sparse features in Example proto.
|
||||
feature_map.update(
|
||||
{k: sparse_float32 for k in ['image/object/bbox/xmin',
|
||||
'image/object/bbox/ymin',
|
||||
'image/object/bbox/xmax',
|
||||
'image/object/bbox/ymax']})
|
||||
|
||||
features = tf.parse_single_example(example_serialized, feature_map)
|
||||
label = tf.cast(features['image/class/label'], dtype=tf.int32)
|
||||
|
||||
xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
|
||||
ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
|
||||
xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
|
||||
ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
|
||||
|
||||
# Note that we impose an ordering of (y, x) just to make life difficult.
|
||||
bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
|
||||
|
||||
# Force the variable number of bounding boxes into the shape
|
||||
# [1, num_boxes, coords].
|
||||
bbox = tf.expand_dims(bbox, 0)
|
||||
bbox = tf.transpose(bbox, [0, 2, 1])
|
||||
|
||||
return features['image/encoded'], label, bbox, features['image/class/text']
|
||||
|
||||
|
||||
class ImagenetPreprocessor:
|
||||
def __init__(self, image_size, dtype, train):
|
||||
self.image_size = image_size
|
||||
self.dtype = dtype
|
||||
self.train = train
|
||||
|
||||
def preprocess(self, image_buffer, bbox):
|
||||
# pylint: disable=g-import-not-at-top
|
||||
image = preprocess_image(image_buffer, bbox, self.image_size, self.image_size, IMAGE_DEPTH, is_training=self.train)
|
||||
return tf.cast(image, self.dtype)
|
||||
|
||||
def parse_and_preprocess(self, value):
|
||||
image_buffer, label_index, bbox, _ = parse_example_proto(value)
|
||||
image = self.preprocess(image_buffer, bbox)
|
||||
image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
|
||||
return label_index, image
|
||||
|
105
EBMs/inception.py
Normal file
@ -0,0 +1,105 @@
# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
import tarfile
|
||||
|
||||
import numpy as np
|
||||
from six.moves import urllib
|
||||
import tensorflow as tf
|
||||
import glob
|
||||
import scipy.misc
|
||||
import math
|
||||
import sys
|
||||
|
||||
import horovod.tensorflow as hvd
|
||||
|
||||
MODEL_DIR = '/tmp/imagenet'
|
||||
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
||||
softmax = None
|
||||
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.visible_device_list = str(hvd.local_rank())
|
||||
sess = tf.Session(config=config)
|
||||
|
||||
# Call this function with a list of images. Each element should be a
|
||||
# numpy array with values ranging from 0 to 255.
|
||||
def get_inception_score(images, splits=10):
|
||||
# For convenience
|
||||
if len(images[0].shape) != 3:
|
||||
return 0, 0
|
||||
|
||||
# Bypassing all the assertions so that we don't end prematurely
|
||||
# assert(type(images) == list)
|
||||
# assert(type(images[0]) == np.ndarray)
|
||||
# assert(len(images[0].shape) == 3)
|
||||
# assert(np.max(images[0]) > 10)
|
||||
# assert(np.min(images[0]) >= 0.0)
|
||||
inps = []
|
||||
for img in images:
|
||||
img = img.astype(np.float32)
|
||||
inps.append(np.expand_dims(img, 0))
|
||||
bs = 1
|
||||
preds = []
|
||||
n_batches = int(math.ceil(float(len(inps)) / float(bs)))
|
||||
for i in range(n_batches):
|
||||
sys.stdout.write(".")
|
||||
sys.stdout.flush()
|
||||
inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
|
||||
inp = np.concatenate(inp, 0)
|
||||
pred = sess.run(softmax, {'ExpandDims:0': inp})
|
||||
preds.append(pred)
|
||||
preds = np.concatenate(preds, 0)
|
||||
scores = []
|
||||
for i in range(splits):
|
||||
part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
|
||||
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
||||
kl = np.mean(np.sum(kl, 1))
|
||||
scores.append(np.exp(kl))
|
||||
return np.mean(scores), np.std(scores)
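# illustrative sketch (not in the original file): the score above is
# exp(mean_x KL(p(y|x) || p(y))), averaged over splits; for a toy softmax
# prediction matrix the same quantity can be reproduced with plain numpy.
def _toy_inception_score(preds):
    # preds: (n_images, n_classes) array of softmax outputs, all entries > 0
    marginal = np.mean(preds, 0, keepdims=True)
    kl = np.sum(preds * (np.log(preds) - np.log(marginal)), 1)
    return np.exp(np.mean(kl))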
|
||||
|
||||
# This function is called automatically.
|
||||
def _init_inception():
|
||||
global softmax
|
||||
if not os.path.exists(MODEL_DIR):
|
||||
os.makedirs(MODEL_DIR)
|
||||
filename = DATA_URL.split('/')[-1]
|
||||
filepath = os.path.join(MODEL_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
|
||||
filename, float(count * block_size) / float(total_size) * 100.0))
|
||||
sys.stdout.flush()
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
|
||||
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
|
||||
with tf.gfile.FastGFile(os.path.join(
|
||||
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name='')
|
||||
# Works with an arbitrary minibatch size.
|
||||
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.set_shape(tf.TensorShape(new_shape))
|
||||
w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
|
||||
logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
|
||||
softmax = tf.nn.softmax(logits)
|
||||
|
||||
if softmax is None:
|
||||
_init_inception()
|
622
EBMs/models.py
Normal file
@ -0,0 +1,622 @@
import tensorflow as tf
|
||||
from tensorflow.python.platform import flags
|
||||
import numpy as np
|
||||
from utils import conv_block, get_weight, attention, conv_cond_concat, init_conv_weight, init_attention_weight, init_res_weight, smart_res_block, smart_res_block_optim, init_convt_weight
|
||||
from utils import init_fc_weight, smart_conv_block, smart_fc_block, smart_atten_block, groupsort, smart_convt_block, swish
|
||||
|
||||
flags.DEFINE_bool('swish_act', False, 'use the swish activation for dsprites')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class MnistNet(object):
|
||||
def __init__(self, num_channels=1, num_filters=64):
|
||||
|
||||
self.channels = num_channels
|
||||
self.dim_hidden = num_filters
|
||||
self.datasource = FLAGS.datasource
|
||||
|
||||
if FLAGS.cclass:
|
||||
self.label_size = 10
|
||||
else:
|
||||
self.label_size = 0
|
||||
|
||||
def construct_weights(self, scope=''):
|
||||
weights = {}
|
||||
|
||||
dtype = tf.float32
|
||||
conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
|
||||
fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
|
||||
|
||||
classes = 1
|
||||
|
||||
with tf.variable_scope(scope):
|
||||
init_conv_weight(weights, 'c1_pre', 3, 1, 64)
|
||||
init_conv_weight(weights, 'c1', 4, 64, self.dim_hidden, classes=classes)
|
||||
init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_fc_weight(weights, 'fc_dense', 4*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True)
|
||||
init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False)
|
||||
|
||||
if FLAGS.cclass:
|
||||
self.label_size = 10
|
||||
else:
|
||||
self.label_size = 0
|
||||
return weights
|
||||
|
||||
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, **kwargs):
|
||||
channels = self.channels
|
||||
weights = weights.copy()
|
||||
inp = tf.reshape(inp, (tf.shape(inp)[0], 28, 28, 1))
|
||||
|
||||
if FLAGS.swish_act:
|
||||
act = swish
|
||||
else:
|
||||
act = tf.nn.leaky_relu
|
||||
|
||||
if stop_grad:
|
||||
for k, v in weights.items():
|
||||
if type(v) == dict:
|
||||
v = v.copy()
|
||||
weights[k] = v
|
||||
for k_sub, v_sub in v.items():
|
||||
v[k_sub] = tf.stop_gradient(v_sub)
|
||||
else:
|
||||
weights[k] = tf.stop_gradient(v)
|
||||
|
||||
if FLAGS.cclass:
|
||||
label_d = tf.reshape(label, shape=(tf.shape(label)[0], 1, 1, self.label_size))
|
||||
inp = conv_cond_concat(inp, label_d)
|
||||
|
||||
h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
|
||||
h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act)
|
||||
h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act)
|
||||
h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=False, extra_bias=False, activation=act)
|
||||
|
||||
h5 = tf.reshape(h4, [-1, np.prod([int(dim) for dim in h4.get_shape()[1:]])])
|
||||
h6 = act(smart_fc_block(h5, weights, reuse, 'fc_dense'))
|
||||
hidden6 = smart_fc_block(h6, weights, reuse, 'fc5')
|
||||
|
||||
return hidden6
|
||||
|
||||
|
||||
class DspritesNet(object):
|
||||
def __init__(self, num_channels=1, num_filters=64, cond_size=False, cond_shape=False, cond_pos=False,
|
||||
cond_rot=False, label_size=1):
|
||||
|
||||
self.channels = num_channels
|
||||
self.dim_hidden = num_filters
|
||||
self.img_size = 64
|
||||
self.label_size = label_size
|
||||
|
||||
if FLAGS.cclass:
|
||||
self.label_size = 3
|
||||
|
||||
try:
|
||||
if FLAGS.dshape_only:
|
||||
self.label_size = 3
|
||||
|
||||
if FLAGS.dpos_only:
|
||||
self.label_size = 2
|
||||
|
||||
if FLAGS.dsize_only:
|
||||
self.label_size = 1
|
||||
|
||||
if FLAGS.drot_only:
|
||||
self.label_size = 2
|
||||
except:
|
||||
pass
|
||||
|
||||
if cond_size:
|
||||
self.label_size = 1
|
||||
|
||||
if cond_shape:
|
||||
self.label_size = 3
|
||||
|
||||
if cond_pos:
|
||||
self.label_size = 2
|
||||
|
||||
if cond_rot:
|
||||
self.label_size = 2
|
||||
|
||||
self.cond_size = cond_size
|
||||
self.cond_shape = cond_shape
|
||||
self.cond_pos = cond_pos
|
||||
|
||||
def construct_weights(self, scope=''):
|
||||
weights = {}
|
||||
|
||||
dtype = tf.float32
|
||||
conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
|
||||
fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
|
||||
k = 5
|
||||
classes = self.label_size
|
||||
|
||||
with tf.variable_scope(scope):
|
||||
init_conv_weight(weights, 'c1_pre', 3, 1, 32)
|
||||
init_conv_weight(weights, 'c1', 4, 32, self.dim_hidden, classes=classes)
|
||||
init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_conv_weight(weights, 'c4', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_fc_weight(weights, 'fc_dense', 2*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True)
|
||||
init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False)
|
||||
|
||||
return weights
|
||||
|
||||
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False, return_logit=False):
|
||||
channels = self.channels
|
||||
batch_size = tf.shape(inp)[0]
|
||||
|
||||
inp = tf.reshape(inp, (batch_size, 64, 64, 1))
|
||||
|
||||
if FLAGS.swish_act:
|
||||
act = swish
|
||||
else:
|
||||
act = tf.nn.leaky_relu
|
||||
|
||||
if not FLAGS.cclass:
|
||||
label = None
|
||||
|
||||
weights = weights.copy()
|
||||
|
||||
if stop_grad:
|
||||
for k, v in weights.items():
|
||||
if type(v) == dict:
|
||||
v = v.copy()
|
||||
weights[k] = v
|
||||
for k_sub, v_sub in v.items():
|
||||
v[k_sub] = tf.stop_gradient(v_sub)
|
||||
else:
|
||||
weights[k] = tf.stop_gradient(v)
|
||||
|
||||
h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
|
||||
h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
|
||||
h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
|
||||
h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=True, extra_bias=True, activation=act)
|
||||
h5 = smart_conv_block(h4, weights, reuse, 'c4', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
|
||||
|
||||
hidden6 = tf.reshape(h5, (tf.shape(h5)[0], -1))
|
||||
hidden7 = act(smart_fc_block(hidden6, weights, reuse, 'fc_dense'))
|
||||
energy = smart_fc_block(hidden7, weights, reuse, 'fc5')
|
||||
|
||||
if return_logit:
|
||||
return hidden7
|
||||
else:
|
||||
return energy
|
||||
|
||||
|
||||
|
||||
class ResNet32(object):
|
||||
def __init__(self, num_channels=3, num_filters=128):
|
||||
|
||||
self.channels = num_channels
|
||||
self.dim_hidden = num_filters
|
||||
self.groupsort = groupsort()
|
||||
|
||||
def construct_weights(self, scope=''):
|
||||
weights = {}
|
||||
dtype = tf.float32
|
||||
|
||||
if FLAGS.cclass:
|
||||
classes = 10
|
||||
else:
|
||||
classes = 1
|
||||
|
||||
with tf.variable_scope(scope):
|
||||
# First block
|
||||
init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
|
||||
init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_2', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_3', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden)
|
||||
init_fc_weight(weights, 'fc5', 2*self.dim_hidden , 1, spec_norm=False)
|
||||
|
||||
init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True)
|
||||
|
||||
return weights
|
||||
|
||||
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
|
||||
weights = weights.copy()
|
||||
batch = tf.shape(inp)[0]
|
||||
|
||||
act = tf.nn.leaky_relu
|
||||
|
||||
if not FLAGS.cclass:
|
||||
label = None
|
||||
|
||||
if stop_grad:
|
||||
for k, v in weights.items():
|
||||
if type(v) == dict:
|
||||
v = v.copy()
|
||||
weights[k] = v
|
||||
for k_sub, v_sub in v.items():
|
||||
v[k_sub] = tf.stop_gradient(v_sub)
|
||||
else:
|
||||
weights[k] = tf.stop_gradient(v)
|
||||
|
||||
# Make sure gradients are modified a bit
|
||||
inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
|
||||
|
||||
hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, act=act)
|
||||
hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, act=act)
|
||||
hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, label=label, act=act)
|
||||
|
||||
if FLAGS.use_attention:
|
||||
hidden4 = smart_atten_block(hidden3, weights, reuse, 'atten', stop_at_grad=stop_at_grad, label=label)
|
||||
else:
|
||||
hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, act=act)
|
||||
|
||||
hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', stop_batch=stop_batch, adaptive=False, label=label, act=act)
|
||||
compact = hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
|
||||
hidden6 = tf.nn.relu(hidden6)
|
||||
hidden5 = tf.reduce_sum(hidden6, [1, 2])
|
||||
|
||||
hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
|
||||
|
||||
energy = hidden6
|
||||
|
||||
return energy
|
||||
|
||||
|
||||
class ResNet32Large(object):
|
||||
def __init__(self, num_channels=3, num_filters=128, train=False):
|
||||
|
||||
self.channels = num_channels
|
||||
self.dim_hidden = num_filters
|
||||
self.dropout = train
|
||||
self.train = train
|
||||
|
||||
def construct_weights(self, scope=''):
|
||||
weights = {}
|
||||
dtype = tf.float32
|
||||
|
||||
if FLAGS.cclass:
|
||||
classes = 10
|
||||
else:
|
||||
classes = 1
|
||||
|
||||
with tf.variable_scope(scope):
|
||||
# First block
|
||||
init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
|
||||
init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
|
||||
|
||||
init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden, trainable_gamma=True)
|
||||
|
||||
return weights
|
||||
|
||||
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
|
||||
weights = weights.copy()
|
||||
batch = tf.shape(inp)[0]
|
||||
|
||||
if not FLAGS.cclass:
|
||||
label = None
|
||||
|
||||
if stop_grad:
|
||||
for k, v in weights.items():
|
||||
if type(v) == dict:
|
||||
v = v.copy()
|
||||
weights[k] = v
|
||||
for k_sub, v_sub in v.items():
|
||||
v[k_sub] = tf.stop_gradient(v_sub)
|
||||
else:
|
||||
weights[k] = tf.stop_gradient(v)
|
||||
|
||||
# Make sure gradients are modified a bit
|
||||
inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
|
||||
|
||||
dropout = self.dropout
|
||||
train = self.train
|
||||
|
||||
hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, dropout=dropout, train=train)
|
||||
hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train)
|
||||
hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train)
|
||||
hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train)
|
||||
|
||||
if FLAGS.use_attention:
|
||||
hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
|
||||
else:
|
||||
hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
|
||||
|
||||
hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
|
||||
|
||||
hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train)
|
||||
hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
|
||||
|
||||
compact = hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
|
||||
|
||||
if FLAGS.cclass:
|
||||
hidden6 = tf.nn.leaky_relu(hidden9)
|
||||
else:
|
||||
hidden6 = tf.nn.relu(hidden9)
|
||||
hidden5 = tf.reduce_sum(hidden6, [1, 2])
|
||||
|
||||
hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
|
||||
|
||||
energy = hidden6
|
||||
|
||||
return energy
|
||||
|
||||
|
||||
class ResNet32Wider(object):
|
||||
def __init__(self, num_channels=3, num_filters=128, train=False):
|
||||
|
||||
self.channels = num_channels
|
||||
self.dim_hidden = num_filters
|
||||
self.dropout = train
|
||||
self.train = train
|
||||
|
||||
def construct_weights(self, scope=''):
|
||||
weights = {}
|
||||
dtype = tf.float32
|
||||
|
||||
if FLAGS.cclass and FLAGS.dataset == "cifar10":
|
||||
classes = 10
|
||||
elif FLAGS.cclass and FLAGS.dataset == "imagenet":
|
||||
classes = 1000
|
||||
else:
|
||||
classes = 1
|
||||
|
||||
with tf.variable_scope(scope):
|
||||
# First block
|
||||
init_conv_weight(weights, 'c1_pre', 3, self.channels, 128)
|
||||
init_res_weight(weights, 'res_optim', 3, 128, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
|
||||
|
||||
init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden // 2, trainable_gamma=True)
|
||||
|
||||
return weights
|
||||
|
||||
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
|
||||
weights = weights.copy()
|
||||
batch = tf.shape(inp)[0]
|
||||
|
||||
if not FLAGS.cclass:
|
||||
label = None
|
||||
|
||||
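# When stop_grad is set, detach every weight tensor (including the nested
# per-layer dicts) so this forward pass contributes no parameter gradients.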
if stop_grad:
|
||||
for k, v in weights.items():
|
||||
if type(v) == dict:
|
||||
v = v.copy()
|
||||
weights[k] = v
|
||||
for k_sub, v_sub in v.items():
|
||||
v[k_sub] = tf.stop_gradient(v_sub)
|
||||
else:
|
||||
weights[k] = tf.stop_gradient(v)
|
||||
|
||||
if FLAGS.swish_act:
|
||||
act = swish
|
||||
else:
|
||||
act = tf.nn.leaky_relu
|
||||
|
||||
# Make sure gradients are modified a bit
|
||||
inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
|
||||
dropout = self.dropout
|
||||
train = self.train
|
||||
|
||||
hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=True, label=label, dropout=dropout, train=train)
|
||||
|
||||
if FLAGS.use_attention:
|
||||
hidden2 = smart_atten_block(hidden1, weights, reuse, 'atten', train=train, dropout=dropout, stop_at_grad=stop_at_grad)
|
||||
else:
|
||||
hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act)
|
||||
|
||||
hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act)
|
||||
hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
|
||||
|
||||
hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
|
||||
|
||||
hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
|
||||
|
||||
hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
|
||||
hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
|
||||
|
||||
hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
|
||||
|
||||
if FLAGS.swish_act:
|
||||
hidden6 = act(hidden9)
|
||||
else:
|
||||
hidden6 = tf.nn.relu(hidden9)
|
||||
|
||||
hidden5 = tf.reduce_sum(hidden6, [1, 2])
|
||||
hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
|
||||
energy = hidden6
|
||||
|
||||
return energy
|
||||
|
||||
|
||||
class ResNet32Larger(object):
|
||||
def __init__(self, num_channels=3, num_filters=128):
|
||||
|
||||
self.channels = num_channels
|
||||
self.dim_hidden = num_filters
|
||||
|
||||
def construct_weights(self, scope=''):
|
||||
weights = {}
|
||||
dtype = tf.float32
|
||||
|
||||
if FLAGS.cclass:
|
||||
classes = 10
|
||||
else:
|
||||
classes = 1
|
||||
|
||||
with tf.variable_scope(scope):
|
||||
# First block
|
||||
init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
|
||||
init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_2a', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_2b', 3, self.dim_hidden, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_5a', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_5b', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_8a', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_8b', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden)
|
||||
init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
|
||||
|
||||
init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden // 2, trainable_gamma=True)
|
||||
|
||||
return weights
|
||||
|
||||
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
|
||||
weights = weights.copy()
|
||||
batch = tf.shape(inp)[0]
|
||||
|
||||
if not FLAGS.cclass:
|
||||
label = None
|
||||
|
||||
if stop_grad:
|
||||
for k, v in weights.items():
|
||||
if type(v) == dict:
|
||||
v = v.copy()
|
||||
weights[k] = v
|
||||
for k_sub, v_sub in v.items():
|
||||
v[k_sub] = tf.stop_gradient(v_sub)
|
||||
else:
|
||||
weights[k] = tf.stop_gradient(v)
|
||||
|
||||
# Make sure gradients are modified a bit
|
||||
inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
|
||||
|
||||
hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label)
|
||||
hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
|
||||
hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
|
||||
hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2a', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
|
||||
hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2b', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
|
||||
hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label)
|
||||
|
||||
if FLAGS.use_attention:
|
||||
hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
|
||||
else:
|
||||
hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
|
||||
|
||||
hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
|
||||
|
||||
hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
|
||||
hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
|
||||
hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label)
|
||||
hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
|
||||
hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
|
||||
hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
|
||||
compact = hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
|
||||
|
||||
if FLAGS.cclass:
|
||||
hidden6 = tf.nn.leaky_relu(hidden9)
|
||||
else:
|
||||
hidden6 = tf.nn.relu(hidden9)
|
||||
hidden5 = tf.reduce_sum(hidden6, [1, 2])
|
||||
|
||||
hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
|
||||
|
||||
energy = hidden6
|
||||
|
||||
return energy
|
||||
|
||||
|
||||
class ResNet128(object):
|
||||
"""Construct the convolutional network specified in MAML"""
|
||||
|
||||
def __init__(self, num_channels=3, num_filters=64, train=False):
|
||||
|
||||
self.channels = num_channels
|
||||
self.dim_hidden = num_filters
|
||||
self.dropout = train
|
||||
self.train = train
|
||||
|
||||
def construct_weights(self, scope=''):
|
||||
weights = {}
|
||||
dtype = tf.float32
|
||||
|
||||
classes = 1000
|
||||
|
||||
with tf.variable_scope(scope):
|
||||
# First block
|
||||
init_conv_weight(weights, 'c1_pre', 3, self.channels, 64)
|
||||
init_res_weight(weights, 'res_optim', 3, 64, self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 8*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_9', 3, 8*self.dim_hidden, 8*self.dim_hidden, classes=classes)
|
||||
init_res_weight(weights, 'res_10', 3, 8*self.dim_hidden, 8*self.dim_hidden, classes=classes)
|
||||
init_fc_weight(weights, 'fc5', 8*self.dim_hidden , 1, spec_norm=False)
|
||||
|
||||
|
||||
init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden // 2, trainable_gamma=True)
|
||||
|
||||
return weights
|
||||
|
||||
def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
|
||||
weights = weights.copy()
|
||||
batch = tf.shape(inp)[0]
|
||||
|
||||
if not FLAGS.cclass:
|
||||
label = None
|
||||
|
||||
|
||||
if stop_grad:
|
||||
for k, v in weights.items():
|
||||
if type(v) == dict:
|
||||
v = v.copy()
|
||||
weights[k] = v
|
||||
for k_sub, v_sub in v.items():
|
||||
v[k_sub] = tf.stop_gradient(v_sub)
|
||||
else:
|
||||
weights[k] = tf.stop_gradient(v)
|
||||
|
||||
if FLAGS.swish_act:
|
||||
act = swish
|
||||
else:
|
||||
act = tf.nn.leaky_relu
|
||||
|
||||
dropout = self.dropout
|
||||
train = self.train
|
||||
|
||||
# Make sure gradients are modified a bit
|
||||
inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
|
||||
hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', label=label, dropout=dropout, train=train, downsample=True, adaptive=False)
|
||||
|
||||
if FLAGS.use_attention:
|
||||
hidden1 = smart_atten_block(hidden1, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
|
||||
|
||||
hidden2 = smart_res_block(hidden1, weights, reuse, 'res_3', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act)
|
||||
hidden3 = smart_res_block(hidden2, weights, reuse, 'res_5', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act)
|
||||
hidden4 = smart_res_block(hidden3, weights, reuse, 'res_7', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=True)
|
||||
hidden5 = smart_res_block(hidden4, weights, reuse, 'res_9', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=False)
|
||||
hidden6 = smart_res_block(hidden5, weights, reuse, 'res_10', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=False, adaptive=False)
|
||||
|
||||
if FLAGS.swish_act:
|
||||
hidden6 = act(hidden6)
|
||||
else:
|
||||
hidden6 = tf.nn.relu(hidden6)
|
||||
|
||||
hidden5 = tf.reduce_sum(hidden6, [1, 2])
|
||||
hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
|
||||
energy = hidden6
|
||||
|
||||
return energy
|
18
EBMs/requirements.txt
Normal file
@@ -0,0 +1,18 @@
scipy==1.10.0
horovod==0.24.0
torch==1.13.1
torchvision==0.6.0
six==1.11.0
imageio==2.8.0
tqdm==4.46.0
matplotlib==3.2.1
mpi4py==3.0.3
numpy==1.22.0
Pillow==10.0.1
baselines==0.1.5
scikit-image==0.14.2
scikit_learn
tensorflow==2.11.1
cloudpickle==1.3.0
Cython==0.29.17
mujoco-py==1.50.1.68
333
EBMs/test_inception.py
Normal file
@@ -0,0 +1,333 @@
import tensorflow as tf
|
||||
import numpy as np
|
||||
from tensorflow.python.platform import flags
|
||||
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
|
||||
import os.path as osp
|
||||
import os
|
||||
from utils import optimistic_restore, remap_restore, optimistic_remap_restore
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
from imageio import imsave  # scipy.misc.imsave no longer exists in the pinned SciPy; imageio's imsave is a drop-in replacement
|
||||
from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, TFImagenetLoader
|
||||
from torch.utils.data import DataLoader
|
||||
from baselines.common.tf_util import initialize
|
||||
|
||||
import horovod.tensorflow as hvd
|
||||
hvd.init()
|
||||
|
||||
from inception import get_inception_score
|
||||
from fid import get_fid_score
|
||||
|
||||
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_string('exp', 'default', 'name of experiments')
|
||||
flags.DEFINE_bool('cclass', False, 'whether to condition on class')
|
||||
|
||||
# Architecture settings
|
||||
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
|
||||
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
|
||||
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
|
||||
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
|
||||
flags.DEFINE_float('step_lr', 10.0, 'Size of steps for gradient descent')
|
||||
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
|
||||
flags.DEFINE_float('proj_norm', 0.05, 'Maximum change of input images')
|
||||
flags.DEFINE_integer('batch_size', 512, 'batch size')
|
||||
flags.DEFINE_integer('resume_iter', -1, 'resume iteration')
|
||||
flags.DEFINE_integer('ensemble', 10, 'number of ensembles')
|
||||
flags.DEFINE_integer('im_number', 50000, 'number of images to generate for evaluation')
|
||||
flags.DEFINE_integer('repeat_scale', 100, 'number of repeat iterations')
|
||||
flags.DEFINE_float('noise_scale', 0.005, 'amount of noise to output')
|
||||
flags.DEFINE_integer('idx', 0, 'save index')
|
||||
flags.DEFINE_integer('nomix', 10, 'number of intervals to stop mixing')
|
||||
flags.DEFINE_bool('scaled', True, 'whether to scale noise added')
|
||||
flags.DEFINE_bool('large_model', False, 'whether to use a small or large model')
|
||||
flags.DEFINE_bool('larger_model', False, 'Whether to use the deeper ResNet32Larger model')
flags.DEFINE_bool('wider_model', False, 'Whether to use the wider ResNet32Wider model')
|
||||
flags.DEFINE_bool('single', False, 'single ')
|
||||
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or imagenet or imagenetfull')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
class InceptionReplayBuffer(object):
|
||||
def __init__(self, size):
|
||||
"""Create Replay buffer.
|
||||
Parameters
|
||||
----------
|
||||
size: int
|
||||
Max number of transitions to store in the buffer. When the buffer
|
||||
overflows the old memories are dropped.
|
||||
"""
|
||||
self._storage = []
|
||||
self._label_storage = []
|
||||
self._maxsize = size
|
||||
self._next_idx = 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self._storage)
|
||||
|
||||
def add(self, ims, labels):
|
||||
batch_size = ims.shape[0]
|
||||
if self._next_idx >= len(self._storage):
|
||||
self._storage.extend(list(ims))
|
||||
self._label_storage.extend(list(labels))
|
||||
else:
|
||||
if batch_size + self._next_idx < self._maxsize:
|
||||
self._storage[self._next_idx:self._next_idx+batch_size] = list(ims)
|
||||
self._label_storage[self._next_idx:self._next_idx+batch_size] = list(labels)
|
||||
else:
|
||||
split_idx = self._maxsize - self._next_idx
|
||||
self._storage[self._next_idx:] = list(ims)[:split_idx]
|
||||
self._storage[:batch_size-split_idx] = list(ims)[split_idx:]
|
||||
self._label_storage[self._next_idx:] = list(labels)[:split_idx]
|
||||
self._label_storage[:batch_size-split_idx] = list(labels)[split_idx:]
|
||||
|
||||
self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize
|
||||
|
||||
def _encode_sample(self, idxes):
|
||||
ims = []
|
||||
labels = []
|
||||
for i in idxes:
|
||||
ims.append(self._storage[i])
|
||||
labels.append(self._label_storage[i])
|
||||
return np.array(ims), np.array(labels)
|
||||
|
||||
def sample(self, batch_size):
|
||||
"""Sample a batch of experiences.
|
||||
Parameters
|
||||
----------
|
||||
batch_size: int
|
||||
How many transitions to sample.
|
||||
Returns
|
||||
-------
|
||||
obs_batch: np.array
|
||||
batch of observations
|
||||
act_batch: np.array
|
||||
batch of actions executed given obs_batch
|
||||
rew_batch: np.array
|
||||
rewards received as results of executing act_batch
|
||||
next_obs_batch: np.array
|
||||
next set of observations seen after executing act_batch
|
||||
done_mask: np.array
|
||||
done_mask[i] = 1 if executing act_batch[i] resulted in
|
||||
the end of an episode and 0 otherwise.
|
||||
"""
|
||||
idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
|
||||
return self._encode_sample(idxes), idxes
|
||||
|
||||
def set_elms(self, idxes, data, labels):
|
||||
for i, ix in enumerate(idxes):
|
||||
self._storage[ix] = data[i]
|
||||
self._label_storage[ix] = labels[i]
|
||||
|
||||
|
||||
def rescale_im(im):
|
||||
return np.clip(im * 256, 0, 255).astype(np.uint8)
|
||||
|
||||
def compute_inception(sess, target_vars):
|
||||
X_START = target_vars['X_START']
|
||||
Y_GT = target_vars['Y_GT']
|
||||
X_finals = target_vars['X_finals']
|
||||
NOISE_SCALE = target_vars['NOISE_SCALE']
|
||||
energy_noise = target_vars['energy_noise']
|
||||
|
||||
size = FLAGS.im_number
|
||||
num_steps = size // 1000
|
||||
|
||||
images = []
|
||||
test_ims = []
|
||||
|
||||
|
||||
if FLAGS.dataset == "cifar10":
|
||||
test_dataset = Cifar10(full=True, noise=False)
|
||||
elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
|
||||
test_dataset = Imagenet(train=False)
|
||||
|
||||
if FLAGS.dataset != "imagenetfull":
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False)
|
||||
else:
|
||||
test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)
|
||||
|
||||
for data_corrupt, data, label_gt in tqdm(test_dataloader):
|
||||
data = data.numpy()
|
||||
test_ims.extend(list(rescale_im(data)))
|
||||
|
||||
if FLAGS.dataset == "imagenetfull" and len(test_ims) > 60000:
|
||||
test_ims = test_ims[:60000]
|
||||
break
|
||||
|
||||
|
||||
# n = min(len(images), len(test_ims))
|
||||
print(len(test_ims))
|
||||
# fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
|
||||
# print("Base FID of score {}".format(fid))
|
||||
|
||||
if FLAGS.dataset == "cifar10":
|
||||
classes = 10
|
||||
else:
|
||||
classes = 1000
|
||||
|
||||
if FLAGS.dataset == "imagenetfull":
|
||||
n = 128
|
||||
else:
|
||||
n = 32
|
||||
|
||||
for j in range(num_steps):
|
||||
itr = int(1000 / 500 * FLAGS.repeat_scale)
|
||||
data_buffer = InceptionReplayBuffer(1000)
|
||||
curr_index = 0
|
||||
|
||||
identity = np.eye(classes)
|
||||
|
||||
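# Fill a pool of 1000 samples: while the buffer is still filling, start chains
# from uniform noise; afterwards, continue chains sampled from the buffer,
# occasionally re-initializing ~1% of images and ~10% of labels so the chains
# keep mixing (disabled for the final FLAGS.nomix iterations).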
for i in tqdm(range(itr)):
|
||||
model_index = curr_index % len(X_finals)
|
||||
x_final = X_finals[model_index]
|
||||
|
||||
noise_scale = [1]
|
||||
if len(data_buffer) < 1000:
|
||||
x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
|
||||
label = np.random.randint(0, classes, (FLAGS.batch_size))
|
||||
label = identity[label]
|
||||
x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0]
|
||||
data_buffer.add(x_new, label)
|
||||
else:
|
||||
(x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
|
||||
keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99)
|
||||
label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9)
|
||||
label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
|
||||
label_corrupt = identity[label_corrupt]
|
||||
x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
|
||||
|
||||
if i < itr - FLAGS.nomix:
|
||||
x_init[keep_mask] = x_init_corrupt[keep_mask]
|
||||
label[label_keep_mask] = label_corrupt[label_keep_mask]
|
||||
# else:
|
||||
# noise_scale = [0.7]
|
||||
|
||||
x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})
|
||||
data_buffer.set_elms(idx, x_new, label)
|
||||
|
||||
if FLAGS.im_number != 50000:
|
||||
print(np.mean(e_noise), np.std(e_noise))
|
||||
|
||||
curr_index += 1
|
||||
|
||||
ims = np.array(data_buffer._storage[:1000])
|
||||
ims = rescale_im(ims)
|
||||
|
||||
images.extend(list(ims))
|
||||
|
||||
saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx))
|
||||
|
||||
ims = ims[:100]
|
||||
|
||||
if FLAGS.dataset != "imagenetfull":
|
||||
im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3))
|
||||
else:
|
||||
im_panel = ims.reshape((10, 10, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((1280, 1280, 3))
|
||||
imsave(saveim, im_panel)
|
||||
|
||||
print("Saved image!!!!")
|
||||
splits = max(1, len(images) // 5000)
|
||||
score, std = get_inception_score(images, splits=splits)
|
||||
print("Inception score of {} with std of {}".format(score, std))
|
||||
|
||||
# FID score
|
||||
# n = min(len(images), len(test_ims))
|
||||
fid = get_fid_score(images, test_ims)
|
||||
print("FID of score {}".format(fid))
|
||||
|
||||
|
||||
|
||||
|
||||
def main(model_list):
|
||||
|
||||
if FLAGS.dataset == "imagenetfull":
|
||||
model = ResNet128(num_filters=64)
|
||||
elif FLAGS.large_model:
|
||||
model = ResNet32Large(num_filters=128)
|
||||
elif FLAGS.larger_model:
|
||||
model = ResNet32Larger(num_filters=128)
|
||||
elif FLAGS.wider_model:
|
||||
model = ResNet32Wider(num_filters=256, train=False)
|
||||
else:
|
||||
model = ResNet32(num_filters=128)
|
||||
|
||||
# config = tf.ConfigProto()
|
||||
sess = tf.InteractiveSession()
|
||||
|
||||
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
|
||||
weights = []
|
||||
|
||||
for i, model_num in enumerate(model_list):
|
||||
weight = model.construct_weights('context_{}'.format(i))
|
||||
initialize()
|
||||
save_file = osp.join(logdir, 'model_{}'.format(model_num))
|
||||
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(i))
|
||||
v_map = {(v.name.replace('context_{}'.format(i), 'context_0')[:-2]): v for v in v_list}
|
||||
saver = tf.train.Saver(v_map)
|
||||
try:
|
||||
saver.restore(sess, save_file)
|
||||
except:
|
||||
optimistic_remap_restore(sess, save_file, i)
|
||||
weights.append(weight)
|
||||
|
||||
|
||||
if FLAGS.dataset == "imagenetfull":
|
||||
X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype = tf.float32)
|
||||
else:
|
||||
X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
|
||||
|
||||
if FLAGS.dataset == "cifar10":
|
||||
Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32)
|
||||
else:
|
||||
Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
|
||||
|
||||
NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32)
|
||||
|
||||
X_finals = []
|
||||
|
||||
|
||||
# Seperate loops
|
||||
for weight in weights:
|
||||
X = X_START
|
||||
|
||||
steps = tf.constant(0)
|
||||
c = lambda i, x: tf.less(i, FLAGS.num_steps)
|
||||
def langevin_step(counter, X):
|
||||
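# One Langevin step (what the lines below implement):
#   X <- X + N(0, (noise_scale * NOISE_SCALE)^2)   add exploration noise
#   X <- X - step_lr * dE/dX                       descend the energy
#                                                  (gradient clipped when proj_norm != 0)
#   X <- clip(X, 0, 1)                             stay in the image box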
scale_rate = 1
|
||||
|
||||
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE)
|
||||
|
||||
energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
|
||||
x_grad = tf.gradients(energy_noise, [X])[0]
|
||||
|
||||
if FLAGS.proj_norm != 0.0:
|
||||
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
|
||||
|
||||
X = X - FLAGS.step_lr * x_grad * scale_rate
|
||||
X = tf.clip_by_value(X, 0, 1)
|
||||
|
||||
counter = counter + 1
|
||||
|
||||
return counter, X
|
||||
|
||||
steps, X = tf.while_loop(c, langevin_step, (steps, X))
|
||||
energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
|
||||
X_final = X
|
||||
X_finals.append(X_final)
|
||||
|
||||
target_vars = {}
|
||||
target_vars['X_START'] = X_START
|
||||
target_vars['Y_GT'] = Y_GT
|
||||
target_vars['X_finals'] = X_finals
|
||||
target_vars['NOISE_SCALE'] = NOISE_SCALE
|
||||
target_vars['energy_noise'] = energy_noise
|
||||
|
||||
compute_inception(sess, target_vars)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# model_list = [117000, 116700]
|
||||
model_list = [FLAGS.resume_iter - 300*i for i in range(FLAGS.ensemble)]
|
||||
main(model_list)
|
941
EBMs/train.py
Normal file
@@ -0,0 +1,941 @@
import tensorflow as tf
|
||||
import numpy as np
|
||||
from tensorflow.python.platform import flags
|
||||
|
||||
from data import Imagenet, Cifar10, DSprites, Mnist, TFImagenetLoader
|
||||
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, MnistNet, ResNet128
|
||||
import os.path as osp
|
||||
import os
|
||||
from baselines.logger import TensorBoardOutputFormat
|
||||
from utils import average_gradients, ReplayBuffer, optimistic_restore
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
from torch.utils.data import DataLoader
|
||||
import time as time
|
||||
from io import StringIO
|
||||
from tensorflow.core.util import event_pb2
|
||||
import torch
|
||||
import numpy as np
|
||||
from custom_adam import AdamOptimizer
|
||||
from imageio import imsave  # scipy.misc.imsave no longer exists in the pinned SciPy; imageio's imsave is a drop-in replacement
|
||||
import matplotlib.pyplot as plt
|
||||
from hmc import hmc
|
||||
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
rank = comm.Get_rank()
|
||||
|
||||
import horovod.tensorflow as hvd
|
||||
hvd.init()
|
||||
|
||||
from inception import get_inception_score
|
||||
|
||||
torch.manual_seed(hvd.rank())
|
||||
np.random.seed(hvd.rank())
|
||||
tf.set_random_seed(hvd.rank())
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
# Dataset Options
|
||||
flags.DEFINE_string('datasource', 'random',
|
||||
'initialization for chains, either random or default (decorruption)')
|
||||
flags.DEFINE_string('dataset','mnist',
|
||||
'dsprites, cifar10, imagenet (32x32) or imagenetfull (128x128)')
|
||||
flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
|
||||
flags.DEFINE_bool('single', False, 'whether to debug by training on a single image')
|
||||
flags.DEFINE_integer('data_workers', 4,
|
||||
'Number of different data workers to load data in parallel')
|
||||
|
||||
# General Experiment Settings
|
||||
flags.DEFINE_string('logdir', 'cachedir',
|
||||
'location where log of experiments will be stored')
|
||||
flags.DEFINE_string('exp', 'default', 'name of experiments')
|
||||
flags.DEFINE_integer('log_interval', 10, 'log outputs every so many batches')
|
||||
flags.DEFINE_integer('save_interval', 1000,'save outputs every so many batches')
|
||||
flags.DEFINE_integer('test_interval', 1000,'evaluate outputs every so many batches')
|
||||
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
|
||||
flags.DEFINE_bool('train', True, 'whether to train or test')
|
||||
flags.DEFINE_integer('epoch_num', 10000, 'Number of Epochs to train on')
|
||||
flags.DEFINE_float('lr', 3e-4, 'Learning for training')
|
||||
flags.DEFINE_integer('num_gpus', 1, 'number of gpus to train on')
|
||||
|
||||
# EBM Specific Experiments Settings
|
||||
flags.DEFINE_float('ml_coeff', 1.0, 'Maximum Likelihood Coefficients')
|
||||
flags.DEFINE_float('l2_coeff', 1.0, 'L2 Penalty training')
|
||||
flags.DEFINE_bool('cclass', False, 'Whether to use class-conditional training in models')
|
||||
flags.DEFINE_bool('model_cclass', False,'use unsupervised clustering to infer fake labels')
|
||||
flags.DEFINE_integer('temperature', 1, 'Temperature for energy function')
|
||||
flags.DEFINE_string('objective', 'cd', 'use either contrastive divergence objective (least stable), '
                    'logsumexp (more stable), or '
                    'softplus (most stable)')
|
||||
flags.DEFINE_bool('zero_kl', False, 'whether to zero out the kl loss')
|
||||
|
||||
# Setting for MCMC sampling
|
||||
flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images')
|
||||
flags.DEFINE_string('proj_norm_type', 'li', 'Either li or l2 ball projection')
|
||||
flags.DEFINE_integer('num_steps', 20, 'Steps of gradient descent for training')
|
||||
flags.DEFINE_float('step_lr', 1.0, 'Size of steps for gradient descent')
|
||||
flags.DEFINE_bool('replay_batch', False, 'Use MCMC chains initialized from a replay buffer.')
|
||||
flags.DEFINE_bool('hmc', False, 'Whether to use HMC sampling to train models')
|
||||
flags.DEFINE_float('noise_scale', 1.,'Relative amount of noise for MCMC')
|
||||
flags.DEFINE_bool('pcd', False, 'whether to use pcd training instead')
|
||||
|
||||
# Architecture Settings
|
||||
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets')
|
||||
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
|
||||
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
|
||||
flags.DEFINE_bool('large_model', False, 'whether to use a large model')
|
||||
flags.DEFINE_bool('larger_model', False, 'Deeper ResNet32 Network')
|
||||
flags.DEFINE_bool('wider_model', False, 'Wider ResNet32 Network')
|
||||
|
||||
# Dataset settings
|
||||
flags.DEFINE_bool('mixup', False, 'whether to add mixup to training images')
|
||||
flags.DEFINE_bool('augment', False, 'whether to apply data augmentation to images')
|
||||
flags.DEFINE_float('rescale', 1.0, 'Factor to rescale inputs from 0-1 box')
|
||||
|
||||
# Dsprites specific experiments
|
||||
flags.DEFINE_bool('cond_shape', False, 'condition of shape type')
|
||||
flags.DEFINE_bool('cond_size', False, 'condition of shape size')
|
||||
flags.DEFINE_bool('cond_pos', False, 'condition of position loc')
|
||||
flags.DEFINE_bool('cond_rot', False, 'condition of rot')
|
||||
|
||||
FLAGS.step_lr = FLAGS.step_lr * FLAGS.rescale
|
||||
|
||||
FLAGS.batch_size *= FLAGS.num_gpus
|
||||
|
||||
print("{} batch size".format(FLAGS.batch_size))
|
||||
|
||||
|
||||
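# The replay buffer stores MCMC samples as uint8 to save memory:
# compress_x_mod quantizes samples from [0, rescale] to [0, 255], and
# decompress_x_mod maps them back, adding a small uniform dither so repeated
# quantization does not collapse the chains onto the 256 grid values.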
def compress_x_mod(x_mod):
|
||||
x_mod = (255 * np.clip(x_mod, 0, FLAGS.rescale) / FLAGS.rescale).astype(np.uint8)
|
||||
return x_mod
|
||||
|
||||
|
||||
def decompress_x_mod(x_mod):
|
||||
x_mod = x_mod / 256 * FLAGS.rescale + \
|
||||
np.random.uniform(0, 1 / 256 * FLAGS.rescale, x_mod.shape)
|
||||
return x_mod
|
||||
|
||||
|
||||
def make_image(tensor):
|
||||
"""Convert an numpy representation image to Image protobuf"""
|
||||
from PIL import Image
|
||||
if len(tensor.shape) == 4:
|
||||
_, height, width, channel = tensor.shape
|
||||
elif len(tensor.shape) == 3:
|
||||
height, width, channel = tensor.shape
|
||||
elif len(tensor.shape) == 2:
|
||||
height, width = tensor.shape
|
||||
channel = 1
|
||||
tensor = tensor.astype(np.uint8)
|
||||
image = Image.fromarray(tensor)
|
||||
import io
|
||||
output = io.BytesIO()
|
||||
image.save(output, format='PNG')
|
||||
image_string = output.getvalue()
|
||||
output.close()
|
||||
return tf.Summary.Image(height=height,
|
||||
width=width,
|
||||
colorspace=channel,
|
||||
encoded_image_string=image_string)
|
||||
|
||||
|
||||
def log_image(im, logger, tag, step=0):
|
||||
im = make_image(im)
|
||||
|
||||
summary = [tf.Summary.Value(tag=tag, image=im)]
|
||||
summary = tf.Summary(value=summary)
|
||||
event = event_pb2.Event(summary=summary)
|
||||
event.step = step
|
||||
logger.writer.WriteEvent(event)
|
||||
logger.writer.Flush()
|
||||
|
||||
|
||||
def rescale_im(image):
|
||||
image = np.clip(image, 0, FLAGS.rescale)
|
||||
if FLAGS.dataset == 'mnist' or FLAGS.dataset == 'dsprites':
|
||||
return (np.clip((FLAGS.rescale - image) * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
|
||||
else:
|
||||
return (np.clip(image * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
|
||||
|
||||
|
||||
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
|
||||
X = target_vars['X']
|
||||
Y = target_vars['Y']
|
||||
X_NOISE = target_vars['X_NOISE']
|
||||
train_op = target_vars['train_op']
|
||||
energy_pos = target_vars['energy_pos']
|
||||
energy_neg = target_vars['energy_neg']
|
||||
loss_energy = target_vars['loss_energy']
|
||||
loss_ml = target_vars['loss_ml']
|
||||
loss_total = target_vars['total_loss']
|
||||
gvs = target_vars['gvs']
|
||||
x_grad = target_vars['x_grad']
|
||||
x_grad_first = target_vars['x_grad_first']
|
||||
x_off = target_vars['x_off']
|
||||
temp = target_vars['temp']
|
||||
x_mod = target_vars['x_mod']
|
||||
LABEL = target_vars['LABEL']
|
||||
LABEL_POS = target_vars['LABEL_POS']
|
||||
weights = target_vars['weights']
|
||||
test_x_mod = target_vars['test_x_mod']
|
||||
eps = target_vars['eps_begin']
|
||||
label_ent = target_vars['label_ent']
|
||||
|
||||
if FLAGS.use_attention:
|
||||
gamma = weights[0]['atten']['gamma']
|
||||
else:
|
||||
gamma = tf.zeros(1)
|
||||
|
||||
val_output = [test_x_mod]
|
||||
|
||||
gvs_dict = dict(gvs)
|
||||
|
||||
log_output = [
|
||||
train_op,
|
||||
energy_pos,
|
||||
energy_neg,
|
||||
eps,
|
||||
loss_energy,
|
||||
loss_ml,
|
||||
loss_total,
|
||||
x_grad,
|
||||
x_off,
|
||||
x_mod,
|
||||
gamma,
|
||||
x_grad_first,
|
||||
label_ent,
|
||||
*gvs_dict.keys()]
|
||||
output = [train_op, x_mod]
|
||||
|
||||
replay_buffer = ReplayBuffer(10000)
|
||||
itr = resume_iter
|
||||
x_mod = None
|
||||
gd_steps = 1
|
||||
|
||||
dataloader_iterator = iter(dataloader)
|
||||
best_inception = 0.0
|
||||
|
||||
for epoch in range(FLAGS.epoch_num):
|
||||
for data_corrupt, data, label in dataloader:
|
||||
data_corrupt = data_corrupt_init = data_corrupt.numpy()
|
||||
data_corrupt_init = data_corrupt.copy()
|
||||
|
||||
data = data.numpy()
|
||||
label = label.numpy()
|
||||
|
||||
label_init = label.copy()
|
||||
|
||||
if FLAGS.mixup:
|
||||
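# mixup: blend each image with a randomly permuted partner using a
# per-example Beta(1, 1) (i.e. uniform) mixing coefficient.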
idx = np.random.permutation(data.shape[0])
|
||||
lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
|
||||
data = data * lam + data[idx] * (1 - lam)
|
||||
|
||||
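# Persistent chains: push the latest MCMC samples into the replay buffer and
# initialize most of the next batch of chains from it; the replay_mask keeps
# roughly 95% of entries (when rescale is 1.0), so ~5% restart from noise.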
if FLAGS.replay_batch and (x_mod is not None):
|
||||
replay_buffer.add(compress_x_mod(x_mod))
|
||||
|
||||
if len(replay_buffer) > FLAGS.batch_size:
|
||||
replay_batch = replay_buffer.sample(FLAGS.batch_size)
|
||||
replay_batch = decompress_x_mod(replay_batch)
|
||||
replay_mask = (
|
||||
np.random.uniform(
|
||||
0,
|
||||
FLAGS.rescale,
|
||||
FLAGS.batch_size) > 0.05)
|
||||
data_corrupt[replay_mask] = replay_batch[replay_mask]
|
||||
|
||||
if FLAGS.pcd:
|
||||
if x_mod is not None:
|
||||
data_corrupt = x_mod
|
||||
|
||||
feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}
|
||||
|
||||
if FLAGS.cclass:
|
||||
feed_dict[LABEL] = label
|
||||
feed_dict[LABEL_POS] = label_init
|
||||
|
||||
if itr % FLAGS.log_interval == 0:
|
||||
_, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent, * \
|
||||
grads = sess.run(log_output, feed_dict)
|
||||
|
||||
kvs = {}
|
||||
kvs['e_pos'] = e_pos.mean()
|
||||
kvs['e_pos_std'] = e_pos.std()
|
||||
kvs['e_neg'] = e_neg.mean()
|
||||
kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
|
||||
kvs['e_neg_std'] = e_neg.std()
|
||||
kvs['temp'] = temp
|
||||
kvs['loss_e'] = loss_e.mean()
|
||||
kvs['eps'] = eps.mean()
|
||||
kvs['label_ent'] = label_ent
|
||||
kvs['loss_ml'] = loss_ml.mean()
|
||||
kvs['loss_total'] = loss_total.mean()
|
||||
kvs['x_grad'] = np.abs(x_grad).mean()
|
||||
kvs['x_grad_first'] = np.abs(x_grad_first).mean()
|
||||
kvs['x_off'] = x_off.mean()
|
||||
kvs['iter'] = itr
|
||||
kvs['gamma'] = gamma
|
||||
|
||||
for v, k in zip(grads, [v.name for v in gvs_dict.values()]):
|
||||
kvs[k] = np.abs(v).max()
|
||||
|
||||
string = "Obtained a total of "
|
||||
for key, value in kvs.items():
|
||||
string += "{}: {}, ".format(key, value)
|
||||
|
||||
if hvd.rank() == 0:
|
||||
print(string)
|
||||
logger.writekvs(kvs)
|
||||
else:
|
||||
_, x_mod = sess.run(output, feed_dict)
|
||||
|
||||
if itr % FLAGS.save_interval == 0 and hvd.rank() == 0:
|
||||
saver.save(
|
||||
sess,
|
||||
osp.join(
|
||||
FLAGS.logdir,
|
||||
FLAGS.exp,
|
||||
'model_{}'.format(itr)))
|
||||
|
||||
if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != '2d':
|
||||
try_im = x_mod
|
||||
orig_im = data_corrupt.squeeze()
|
||||
actual_im = rescale_im(data)
|
||||
|
||||
orig_im = rescale_im(orig_im)
|
||||
try_im = rescale_im(try_im).squeeze()
|
||||
|
||||
for i, (im, t_im, actual_im_i) in enumerate(
|
||||
zip(orig_im[:20], try_im[:20], actual_im)):
|
||||
shape = orig_im.shape[1:]
|
||||
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
|
||||
size = shape[1]
|
||||
new_im[:, :size] = im
|
||||
new_im[:, size:2 * size] = t_im
|
||||
new_im[:, 2 * size:] = actual_im_i
|
||||
|
||||
log_image(
|
||||
new_im, logger, 'train_gen_{}'.format(itr), step=i)
|
||||
|
||||
test_im = x_mod
|
||||
|
||||
try:
|
||||
data_corrupt, data, label = next(dataloader_iterator)
|
||||
except BaseException:
|
||||
dataloader_iterator = iter(dataloader)
|
||||
data_corrupt, data, label = next(dataloader_iterator)
|
||||
|
||||
data_corrupt = data_corrupt.numpy()
|
||||
|
||||
if FLAGS.replay_batch and (
|
||||
x_mod is not None) and len(replay_buffer) > 0:
|
||||
replay_batch = replay_buffer.sample(FLAGS.batch_size)
|
||||
replay_batch = decompress_x_mod(replay_batch)
|
||||
replay_mask = (
|
||||
np.random.uniform(
|
||||
0, 1, (FLAGS.batch_size)) > 0.05)
|
||||
data_corrupt[replay_mask] = replay_batch[replay_mask]
|
||||
|
||||
if FLAGS.dataset == 'cifar10' or FLAGS.dataset == 'imagenet' or FLAGS.dataset == 'imagenetfull':
|
||||
n = 128
|
||||
|
||||
if FLAGS.dataset == "imagenetfull":
|
||||
n = 32
|
||||
|
||||
if len(replay_buffer) > n:
|
||||
data_corrupt = decompress_x_mod(replay_buffer.sample(n))
|
||||
elif FLAGS.dataset == 'imagenetfull':
|
||||
data_corrupt = np.random.uniform(
|
||||
0, FLAGS.rescale, (n, 128, 128, 3))
|
||||
else:
|
||||
data_corrupt = np.random.uniform(
|
||||
0, FLAGS.rescale, (n, 32, 32, 3))
|
||||
|
||||
if FLAGS.dataset == 'cifar10':
|
||||
label = np.eye(10)[np.random.randint(0, 10, (n))]
|
||||
else:
|
||||
label = np.eye(1000)[
|
||||
np.random.randint(
|
||||
0, 1000, (n))]
|
||||
|
||||
feed_dict[X_NOISE] = data_corrupt
|
||||
|
||||
feed_dict[X] = data
|
||||
|
||||
if FLAGS.cclass:
|
||||
feed_dict[LABEL] = label
|
||||
|
||||
test_x_mod = sess.run(val_output, feed_dict)
|
||||
|
||||
try_im = test_x_mod
|
||||
orig_im = data_corrupt.squeeze()
|
||||
actual_im = rescale_im(data.numpy())
|
||||
|
||||
orig_im = rescale_im(orig_im)
|
||||
try_im = rescale_im(try_im).squeeze()
|
||||
|
||||
for i, (im, t_im, actual_im_i) in enumerate(
|
||||
zip(orig_im[:20], try_im[:20], actual_im)):
|
||||
|
||||
shape = orig_im.shape[1:]
|
||||
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
|
||||
size = shape[1]
|
||||
new_im[:, :size] = im
|
||||
new_im[:, size:2 * size] = t_im
|
||||
new_im[:, 2 * size:] = actual_im_i
|
||||
log_image(
|
||||
new_im, logger, 'val_gen_{}'.format(itr), step=i)
|
||||
|
||||
score, std = get_inception_score(list(try_im), splits=1)
|
||||
print(
|
||||
"Inception score of {} with std of {}".format(
|
||||
score, std))
|
||||
kvs = {}
|
||||
kvs['inception_score'] = score
|
||||
kvs['inception_score_std'] = std
|
||||
logger.writekvs(kvs)
|
||||
|
||||
if score > best_inception:
|
||||
best_inception = score
|
||||
saver.save(
|
||||
sess,
|
||||
osp.join(
|
||||
FLAGS.logdir,
|
||||
FLAGS.exp,
|
||||
'model_best'))
|
||||
|
||||
if itr > 60000 and FLAGS.dataset == "mnist":
|
||||
assert False
|
||||
itr += 1
|
||||
|
||||
saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))
|
||||
|
||||
|
||||
cifar10_map = {0: 'airplane',
|
||||
1: 'automobile',
|
||||
2: 'bird',
|
||||
3: 'cat',
|
||||
4: 'deer',
|
||||
5: 'dog',
|
||||
6: 'frog',
|
||||
7: 'horse',
|
||||
8: 'ship',
|
||||
9: 'truck'}
|
||||
|
||||
|
||||
def test(target_vars, saver, sess, logger, dataloader):
|
||||
X_NOISE = target_vars['X_NOISE']
|
||||
X = target_vars['X']
|
||||
Y = target_vars['Y']
|
||||
LABEL = target_vars['LABEL']
|
||||
energy_start = target_vars['energy_start']
|
||||
x_mod = target_vars['x_mod']
|
||||
x_mod = target_vars['test_x_mod']
|
||||
energy_neg = target_vars['energy_neg']
|
||||
|
||||
np.random.seed(1)
|
||||
random.seed(1)
|
||||
|
||||
output = [x_mod, energy_start, energy_neg]
|
||||
|
||||
dataloader_iterator = iter(dataloader)
|
||||
data_corrupt, data, label = next(dataloader_iterator)
|
||||
data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()
|
||||
|
||||
orig_im = try_im = data_corrupt
|
||||
|
||||
if FLAGS.cclass:
|
||||
try_im, energy_orig, energy = sess.run(
|
||||
output, {X_NOISE: orig_im, Y: label[0:1], LABEL: label})
|
||||
else:
|
||||
try_im, energy_orig, energy = sess.run(
|
||||
output, {X_NOISE: orig_im, Y: label[0:1]})
|
||||
|
||||
orig_im = rescale_im(orig_im)
|
||||
try_im = rescale_im(try_im)
|
||||
actual_im = rescale_im(data)
|
||||
|
||||
for i, (im, energy_i, t_im, energy, label_i, actual_im_i) in enumerate(
|
||||
zip(orig_im, energy_orig, try_im, energy, label, actual_im)):
|
||||
label_i = np.array(label_i)
|
||||
|
||||
shape = im.shape[1:]
|
||||
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
|
||||
size = shape[1]
|
||||
new_im[:, :size] = im
|
||||
new_im[:, size:2 * size] = t_im
|
||||
|
||||
if FLAGS.cclass:
|
||||
label_i = np.where(label_i == 1)[0][0]
|
||||
if FLAGS.dataset == 'cifar10':
|
||||
log_image(new_im, logger, '{}_{:.4f}_now_{:.4f}_{}'.format(
|
||||
i, energy_i[0], energy[0], cifar10_map[label_i]), step=i)
|
||||
else:
|
||||
log_image(
|
||||
new_im,
|
||||
logger,
|
||||
'{}_{:.4f}_now_{:.4f}_{}'.format(
|
||||
i,
|
||||
energy_i[0],
|
||||
energy[0],
|
||||
label_i),
|
||||
step=i)
|
||||
else:
|
||||
log_image(
|
||||
new_im,
|
||||
logger,
|
||||
'{}_{:.4f}_now_{:.4f}'.format(
|
||||
i,
|
||||
energy_i[0],
|
||||
energy[0]),
|
||||
step=i)
|
||||
|
||||
test_ims = list(try_im)
|
||||
real_ims = list(actual_im)
|
||||
|
||||
for i in tqdm(range(50000 // FLAGS.batch_size + 1)):
|
||||
try:
|
||||
data_corrupt, data, label = next(dataloader_iterator)
|
||||
except BaseException:
|
||||
dataloader_iterator = iter(dataloader)
|
||||
data_corrupt, data, label = next(dataloader_iterator)
|
||||
|
||||
data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()
|
||||
|
||||
if FLAGS.cclass:
|
||||
try_im, energy_orig, energy = sess.run(
|
||||
output, {X_NOISE: data_corrupt, Y: label[0:1], LABEL: label})
|
||||
else:
|
||||
try_im, energy_orig, energy = sess.run(
|
||||
output, {X_NOISE: data_corrupt, Y: label[0:1]})
|
||||
|
||||
try_im = rescale_im(try_im)
|
||||
real_im = rescale_im(data)
|
||||
|
||||
test_ims.extend(list(try_im))
|
||||
real_ims.extend(list(real_im))
|
||||
|
||||
score, std = get_inception_score(test_ims)
|
||||
print("Inception score of {} with std of {}".format(score, std))
|
||||
|
||||
|
||||
def main():
|
||||
print("Local rank: ", hvd.local_rank(), hvd.size())
|
||||
|
||||
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
|
||||
if hvd.rank() == 0:
|
||||
if not osp.exists(logdir):
|
||||
os.makedirs(logdir)
|
||||
logger = TensorBoardOutputFormat(logdir)
|
||||
else:
|
||||
logger = None
|
||||
|
||||
LABEL = None
|
||||
print("Loading data...")
|
||||
if FLAGS.dataset == 'cifar10':
|
||||
dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
|
||||
test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
|
||||
channel_num = 3
|
||||
|
||||
X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
|
||||
X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
|
||||
LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)
|
||||
|
||||
if FLAGS.large_model:
|
||||
model = ResNet32Large(
|
||||
num_channels=channel_num,
|
||||
num_filters=128,
|
||||
train=True)
|
||||
elif FLAGS.larger_model:
|
||||
model = ResNet32Larger(
|
||||
num_channels=channel_num,
|
||||
num_filters=128)
|
||||
elif FLAGS.wider_model:
|
||||
model = ResNet32Wider(
|
||||
num_channels=channel_num,
|
||||
num_filters=192)
|
||||
else:
|
||||
model = ResNet32(
|
||||
num_channels=channel_num,
|
||||
num_filters=128)
|
||||
|
||||
elif FLAGS.dataset == 'imagenet':
|
||||
dataset = Imagenet(train=True)
|
||||
test_dataset = Imagenet(train=False)
|
||||
channel_num = 3
|
||||
X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
|
||||
X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
|
||||
LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
|
||||
|
||||
model = ResNet32Wider(
|
||||
num_channels=channel_num,
|
||||
num_filters=256)
|
||||
|
||||
elif FLAGS.dataset == 'imagenetfull':
|
||||
channel_num = 3
|
||||
X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
|
||||
X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
|
||||
LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
|
||||
|
||||
model = ResNet128(
|
||||
num_channels=channel_num,
|
||||
num_filters=64)
|
||||
|
||||
elif FLAGS.dataset == 'mnist':
|
||||
dataset = Mnist(rescale=FLAGS.rescale)
|
||||
test_dataset = dataset
|
||||
channel_num = 1
|
||||
X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
|
||||
X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
|
||||
LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)
|
||||
|
||||
model = MnistNet(
|
||||
num_channels=channel_num,
|
||||
num_filters=FLAGS.num_filters)
|
||||
|
||||
elif FLAGS.dataset == 'dsprites':
|
||||
dataset = DSprites(
|
||||
cond_shape=FLAGS.cond_shape,
|
||||
cond_size=FLAGS.cond_size,
|
||||
cond_pos=FLAGS.cond_pos,
|
||||
cond_rot=FLAGS.cond_rot)
|
||||
test_dataset = dataset
|
||||
channel_num = 1
|
||||
|
||||
X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
|
||||
X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
|
||||
|
||||
if FLAGS.dpos_only:
|
||||
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
elif FLAGS.dsize_only:
|
||||
LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
|
||||
elif FLAGS.drot_only:
|
||||
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
elif FLAGS.cond_size:
|
||||
LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
|
||||
elif FLAGS.cond_shape:
|
||||
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
|
||||
elif FLAGS.cond_pos:
|
||||
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
elif FLAGS.cond_rot:
|
||||
LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
else:
|
||||
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
|
||||
LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
|
||||
|
||||
model = DspritesNet(
|
||||
num_channels=channel_num,
|
||||
num_filters=FLAGS.num_filters,
|
||||
cond_size=FLAGS.cond_size,
|
||||
cond_shape=FLAGS.cond_shape,
|
||||
cond_pos=FLAGS.cond_pos,
|
||||
cond_rot=FLAGS.cond_rot)
|
||||
|
||||
print("Done loading...")
|
||||
|
||||
if FLAGS.dataset == "imagenetfull":
|
||||
# In the case of full imagenet, use custom_tensorflow dataloader
|
||||
data_loader = TFImagenetLoader('train', FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale)
|
||||
else:
|
||||
data_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_workers=FLAGS.data_workers,
|
||||
drop_last=True,
|
||||
shuffle=True)
|
||||
|
||||
batch_size = FLAGS.batch_size
|
||||
|
||||
weights = [model.construct_weights('context_0')]
|
||||
|
||||
Y = tf.placeholder(shape=(None), dtype=tf.int32)
|
||||
|
||||
# Variables used in training
|
||||
X_SPLIT = tf.split(X, FLAGS.num_gpus)
|
||||
X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
|
||||
LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
|
||||
LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
|
||||
LABEL_SPLIT_INIT = list(LABEL_SPLIT)
|
||||
tower_grads = []
|
||||
tower_gen_grads = []
|
||||
x_mod_list = []
|
||||
|
||||
optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
|
||||
optimizer = hvd.DistributedOptimizer(optimizer)
|
||||
|
||||
for j in range(FLAGS.num_gpus):
|
||||
|
||||
if FLAGS.model_cclass:
|
||||
ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
|
||||
label_tensor = tf.Variable(
|
||||
tf.convert_to_tensor(
|
||||
np.reshape(
|
||||
np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
|
||||
(FLAGS.batch_size * 10, 10)),
|
||||
dtype=tf.float32),
|
||||
trainable=False,
|
||||
dtype=tf.float32)
|
||||
x_split = tf.tile(
|
||||
tf.reshape(
|
||||
X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1))
|
||||
x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
|
||||
energy_pos = model.forward(
|
||||
x_split,
|
||||
weights[0],
|
||||
label=label_tensor,
|
||||
stop_at_grad=False)
|
||||
|
||||
energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
|
||||
energy_partition_est = tf.reduce_logsumexp(
|
||||
energy_pos_full, axis=1, keepdims=True)
|
||||
uniform = tf.random_uniform(tf.shape(energy_pos_full))
|
||||
label_tensor = tf.argmax(-energy_pos_full -
|
||||
tf.log(-tf.log(uniform)) - energy_partition_est, axis=1)
|
||||
label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
|
||||
label = tf.Print(label, [label_tensor, energy_pos_full])
|
||||
LABEL_SPLIT[j] = label
|
||||
energy_pos = tf.concat(energy_pos, axis=0)
|
||||
else:
|
||||
energy_pos = [
|
||||
model.forward(
|
||||
X_SPLIT[j],
|
||||
weights[0],
|
||||
label=LABEL_POS_SPLIT[j],
|
||||
stop_at_grad=False)]
|
||||
energy_pos = tf.concat(energy_pos, axis=0)
|
||||
|
||||
print("Building graph...")
|
||||
x_mod = x_orig = X_NOISE_SPLIT[j]
|
||||
|
||||
x_grads = []
|
||||
|
||||
energy_negs = []
|
||||
loss_energys = []
|
||||
|
||||
energy_negs.extend([model.forward(tf.stop_gradient(
|
||||
x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
|
||||
eps_begin = tf.zeros(1)
|
||||
|
||||
steps = tf.constant(0)
|
||||
c = lambda i, x: tf.less(i, FLAGS.num_steps)
|
||||
|
||||
def langevin_step(counter, x_mod):
|
||||
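# One Langevin step for drawing negative samples: add Gaussian noise, take a
# gradient step downhill on the (temperature-scaled) energy, optionally using
# HMC instead when FLAGS.hmc is set, then clip back to [0, FLAGS.rescale].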
x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
|
||||
mean=0.0,
|
||||
stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)
|
||||
|
||||
energy_noise = energy_start = tf.concat(
|
||||
[model.forward(
|
||||
x_mod,
|
||||
weights[0],
|
||||
label=LABEL_SPLIT[j],
|
||||
reuse=True,
|
||||
stop_at_grad=False,
|
||||
stop_batch=True)],
|
||||
axis=0)
|
||||
|
||||
x_grad, label_grad = tf.gradients(
|
||||
FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]])
|
||||
energy_noise_old = energy_noise
|
||||
|
||||
lr = FLAGS.step_lr
|
||||
|
||||
if FLAGS.proj_norm != 0.0:
|
||||
if FLAGS.proj_norm_type == 'l2':
|
||||
x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
|
||||
elif FLAGS.proj_norm_type == 'li':
|
||||
x_grad = tf.clip_by_value(
|
||||
x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
|
||||
else:
|
||||
print("Other types of projection are not supported!!!")
|
||||
assert False
|
||||
|
||||
# Clip gradient norm for now
|
||||
if FLAGS.hmc:
|
||||
# Step size should be tuned to get around 65% acceptance
|
||||
def energy(x):
|
||||
return FLAGS.temperature * \
|
||||
model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)
|
||||
|
||||
x_last = hmc(x_mod, 15., 10, energy)
|
||||
else:
|
||||
x_last = x_mod - (lr) * x_grad
|
||||
|
||||
x_mod = x_last
|
||||
x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)
|
||||
|
||||
counter = counter + 1
|
||||
|
||||
return counter, x_mod
|
||||
|
||||
steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))
|
||||
|
||||
energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
|
||||
stop_at_grad=False, reuse=True)
|
||||
x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
|
||||
x_grads.append(x_grad)
|
||||
|
||||
energy_negs.append(
|
||||
model.forward(
|
||||
tf.stop_gradient(x_mod),
|
||||
weights[0],
|
||||
label=LABEL_SPLIT[j],
|
||||
stop_at_grad=False,
|
||||
reuse=True))
|
||||
|
||||
test_x_mod = x_mod
|
||||
|
||||
temp = FLAGS.temperature
|
||||
|
||||
energy_neg = energy_negs[-1]
|
||||
x_off = tf.reduce_mean(
|
||||
tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))
|
||||
|
||||
loss_energy = model.forward(
|
||||
x_mod,
|
||||
weights[0],
|
||||
reuse=True,
|
||||
label=LABEL,
|
||||
stop_grad=True)
|
||||
|
||||
print("Finished processing loop construction ...")
|
||||
|
||||
target_vars = {}
|
||||
|
||||
if FLAGS.cclass or FLAGS.model_cclass:
|
||||
label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
|
||||
label_prob = label_sum / tf.reduce_sum(label_sum)
|
||||
label_ent = -tf.reduce_sum(label_prob *
|
||||
tf.math.log(label_prob + 1e-7))
|
||||
else:
|
||||
label_ent = tf.zeros(1)
|
||||
|
||||
target_vars['label_ent'] = label_ent
|
||||
|
||||
if FLAGS.train:
|
||||
|
||||
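# All objectives push energy down on data (energy_pos) and up on sampled
# negatives (energy_neg): 'cd' uses the plain contrastive-divergence gap,
# 'logsumexp' softmax-reweights the negatives by their energies, and
# 'softplus' smooths the positive/negative gap for extra stability.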
if FLAGS.objective == 'logsumexp':
|
||||
pos_term = temp * energy_pos
|
||||
energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
|
||||
coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
|
||||
norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
|
||||
pos_loss = tf.reduce_mean(temp * energy_pos)
|
||||
neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
|
||||
loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
|
||||
elif FLAGS.objective == 'cd':
|
||||
pos_loss = tf.reduce_mean(temp * energy_pos)
|
||||
neg_loss = -tf.reduce_mean(temp * energy_neg)
|
||||
loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
|
||||
elif FLAGS.objective == 'softplus':
|
||||
loss_ml = FLAGS.ml_coeff * \
|
||||
tf.nn.softplus(temp * (energy_pos - energy_neg))
|
||||
|
||||
loss_total = tf.reduce_mean(loss_ml)
|
||||
|
||||
if not FLAGS.zero_kl:
|
||||
loss_total = loss_total + tf.reduce_mean(loss_energy)
|
||||
|
||||
loss_total = loss_total + \
|
||||
FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))
|
||||
|
||||
print("Started gradient computation...")
|
||||
gvs = optimizer.compute_gradients(loss_total)
|
||||
gvs = [(k, v) for (k, v) in gvs if k is not None]
|
||||
|
||||
print("Applying gradients...")
|
||||
|
||||
tower_grads.append(gvs)
|
||||
|
||||
print("Finished applying gradients.")
|
||||
|
||||
target_vars['loss_ml'] = loss_ml
|
||||
target_vars['total_loss'] = loss_total
|
||||
target_vars['loss_energy'] = loss_energy
|
||||
target_vars['weights'] = weights
|
||||
target_vars['gvs'] = gvs
|
||||
|
||||
target_vars['X'] = X
|
||||
target_vars['Y'] = Y
|
||||
target_vars['LABEL'] = LABEL
|
||||
target_vars['LABEL_POS'] = LABEL_POS
|
||||
target_vars['X_NOISE'] = X_NOISE
|
||||
target_vars['energy_pos'] = energy_pos
|
||||
target_vars['energy_start'] = energy_negs[0]
|
||||
|
||||
if len(x_grads) >= 1:
|
||||
target_vars['x_grad'] = x_grads[-1]
|
||||
target_vars['x_grad_first'] = x_grads[0]
|
||||
else:
|
||||
target_vars['x_grad'] = tf.zeros(1)
|
||||
target_vars['x_grad_first'] = tf.zeros(1)
|
||||
|
||||
target_vars['x_mod'] = x_mod
|
||||
target_vars['x_off'] = x_off
|
||||
target_vars['temp'] = temp
|
||||
target_vars['energy_neg'] = energy_neg
|
||||
target_vars['test_x_mod'] = test_x_mod
|
||||
target_vars['eps_begin'] = eps_begin
|
||||
|
||||
if FLAGS.train:
|
||||
grads = average_gradients(tower_grads)
|
||||
train_op = optimizer.apply_gradients(grads)
|
||||
target_vars['train_op'] = train_op
|
||||
|
||||
config = tf.ConfigProto()
|
||||
|
||||
if hvd.size() > 1:
|
||||
config.gpu_options.visible_device_list = str(hvd.local_rank())
|
||||
|
||||
sess = tf.Session(config=config)
|
||||
|
||||
saver = loader = tf.train.Saver(
|
||||
max_to_keep=30, keep_checkpoint_every_n_hours=6)
|
||||
|
||||
total_parameters = 0
|
||||
for variable in tf.trainable_variables():
|
||||
# shape is an array of tf.Dimension
|
||||
shape = variable.get_shape()
|
||||
variable_parameters = 1
|
||||
for dim in shape:
|
||||
variable_parameters *= dim.value
|
||||
total_parameters += variable_parameters
|
||||
print("Model has a total of {} parameters".format(total_parameters))
|
||||
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
resume_itr = 0
|
||||
|
||||
if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
|
||||
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
|
||||
resume_itr = FLAGS.resume_iter
|
||||
# saver.restore(sess, model_file)
|
||||
optimistic_restore(sess, model_file)
|
||||
|
||||
sess.run(hvd.broadcast_global_variables(0))
|
||||
print("Initializing variables...")
|
||||
|
||||
print("Start broadcast")
|
||||
print("End broadcast")
|
||||
|
||||
if FLAGS.train:
|
||||
train(target_vars, saver, sess,
|
||||
logger, data_loader, resume_itr,
|
||||
logdir)
|
||||
|
||||
test(target_vars, saver, sess, logger, data_loader)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1107
EBMs/utils.py
Normal file
File diff suppressed because it is too large