chores: refactor for the new ai research, add linter, gh action, etc (#27)

Marina von Steinkirch, PhD 2025-08-13 21:49:46 +08:00 committed by von-steinkirch
parent fb4ab80dc3
commit d5467e559f
40 changed files with 5177 additions and 2476 deletions


@ -1,12 +1,15 @@
## quantum ai: training energy-based-models using openai
## quantum ai: training energy-based-models using openAI
<br>
#### ⚛️ this repository contains my adapted code from [openai's implicit generation and generalization in energy-based-models](https://arxiv.org/pdf/1903.08689.pdf)
#### ⚛️ this repository contains my adapted code from [openai's implicit generation and generalization in
energy-based-models](https://arxiv.org/pdf/1903.08689.pdf)
<br>
---
### installing
<br>
@ -19,7 +22,8 @@ brew install pkg-config
<br>
* there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi for the specific libraries in this problem (`PMIX ERROR: ERROR`) that can be fixed with:
* there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi for the specific libraries in this
problem (`PMIX ERROR: ERROR`) that can be fixed with:
<br>
@ -40,7 +44,8 @@ pip install -r requirements.txt
```
<br>
* note that this is an adapted requirement file since the **[openai's original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is not complete/correct
* note that this is an adapted requirement file since the **[openai's
original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is not complete/correct
* finally, download and install **[mujoco](https://www.roboti.us/index.html)**
* you will also need to register for a license, which asks for a machine ID
* the documentation on the website is incomplete, so just download the suggested script and run:
@ -64,7 +69,8 @@ mv getid_osx getid_osx.dms
<br>
* download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder `cachedir`:
* download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder
`cachedir`:
<br>
@ -78,7 +84,8 @@ mkdir cachedir
<br>
* openai's original code contains **[hardcoded constants that only work on Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
* openai's original code contains **[hardcoded constants that only work on
Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
* i changed this to a constant (`ROOT_DIR = "./results"`) in the top of `data.py`
<br>
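* as a rough illustration of that change (a sketch under assumptions, not the actual contents of `data.py`), output paths can be built relative to one constant instead of an absolute linux path; `results_path` and `ensure_results_dir` below are hypothetical helper names:

```python
import os
import os.path as osp

# assumption: mirrors the `ROOT_DIR` constant mentioned above
ROOT_DIR = "./results"


def results_path(*parts):
    """Join `parts` under ROOT_DIR (hypothetical helper, not part of data.py)."""
    return osp.join(ROOT_DIR, *parts)


def ensure_results_dir():
    """Create ROOT_DIR if it does not exist yet and return it."""
    os.makedirs(ROOT_DIR, exist_ok=True)
    return ROOT_DIR
```

* because the constant is relative, the directory resolves against whatever folder the scripts are launched from, which should work on both macos and linux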
@ -87,7 +94,8 @@ mkdir cachedir
<br>
* all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so model training can be sped up substantially by running each command with multiple workers:
* all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so model training can be sped up
substantially by running each command with multiple workers:
<br>
@ -102,7 +110,8 @@ mpiexec -n <worker_num> <command>
<br>
```
python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --large_model
python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 \
--zero_kl --replay_batch --large_model
```
* this should generate the following output:
@ -112,7 +121,8 @@ python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_si
```bash
Instructions for updating:
Use tf.gfile.GFile.
2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is
deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
64 batch size
Local rank: 0 1
Loading data...
@ -121,11 +131,15 @@ Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Done loading...
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
WARNING:tensorflow:From
/Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263:
colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
Building graph...
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
WARNING:tensorflow:From
/Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from
tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
Finished processing loop construction ...
@ -136,16 +150,36 @@ Model has a total of 7567880 parameters
Initializing variables...
Start broadcast
End broadcast
Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff: 0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent: 2.272536277770996, l
oss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first: 0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818, context_0/res_optim_res_c1/
cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10, context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10, context_0/res_optim_res_c2/gb:0: 6.27235652306268
3e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11, context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437, context_0/res_1_res_c2/g:0
: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0: 1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0: 6.868505764145993e-10, context_0/res_2
_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0: 8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0: 0.0194275379180908
2, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10, context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11, context_0/res_3_res_c2/gb:0: 7.75612463
1441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0: 4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0: 0.34806951880455017, context_0/res_4_res_c2/g:
0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0: 5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0: 3.807749116013781e-10, context_0/res_5
_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0: 7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039, context_0/fc5/
Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff:
0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent:
2.272536277770996, l
oss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first:
0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818,
context_0/res_optim_res_c1/
cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10,
context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10,
context_0/res_optim_res_c2/gb:0: 6.27235652306268
3e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11,
context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437,
context_0/res_1_res_c2/g:0
: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0:
1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0:
6.868505764145993e-10, context_0/res_2
_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0:
8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0:
0.0194275379180908
2, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10,
context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11,
context_0/res_3_res_c2/gb:0: 7.75612463
1441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0:
4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0:
0.34806951880455017, context_0/res_4_res_c2/g:
0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0:
5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0:
3.807749116013781e-10, context_0/res_5
_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0:
7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039,
context_0/fc5/
b:0: 0.4506262540817261,
................................................................................................................................
@ -159,7 +193,8 @@ Inception score of 1.2397289276123047 with std of 0.0
<br>
```
python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --cclass
python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 \
--zero_kl --replay_batch --cclass
```
<br>
@ -169,7 +204,8 @@ python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size
<br>
```
python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 --step_lr=10.0 --proj_norm=0.01 --replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=<imagenet32x32 path>
python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 --step_lr=10.0 --proj_norm=0.01 \
--replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=<imagenet32x32 path>
```
<br>
@ -179,7 +215,8 @@ python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=3
<br>
```
python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 --step_lr=100.0 --replay_batch --swish_act --cclass --zero_kl --dataset=imagenetfull --imagenet_datadir=<full imagenet path>
python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 --step_lr=100.0 --replay_batch --swish_act --cclass \
--zero_kl --dataset=imagenetfull --imagenet_datadir=<full imagenet path>
```
<br>
@ -197,7 +234,8 @@ python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0
python imagenet_demo.py --exp=imagenet128_cond --resume_iter=2238000 --swish_act
```
* the ebm_sandbox.py file contains several different tasks that can be used to evaluate ebms, which are defined by different settings of task flag in the file
* the ebm_sandbox.py file contains several different tasks that can be used to evaluate ebms, which are defined by
different settings of task flag in the file
* for example, to visualize cross class mappings in cifar-10, you can run:
<br>
@ -217,7 +255,8 @@ python ebm_sandbox.py --task=crossclass --num_steps=40 --exp=cifar10_cond --resu
<br>
```
python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200 --large_model --svhnmix --cclass=False
python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200 \
--large_model --svhnmix --cclass=False
```
<br>
@ -227,7 +266,8 @@ python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_
<br>
```
python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd=<number of pgd steps> --num_steps=10 --lival=<li bound value> --wider_model
python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 \
--pgd=<number of pgd steps> --num_steps=10 --lival=<li bound value> --wider_model
```
<br>
@ -236,12 +276,14 @@ python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=
<br>
* to train ebms on conditional dsprites dataset, you can train each model separately on each conditioned latent in `cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`, with an example command given below:
* to train ebms on conditional dsprites dataset, you can train each model separately on each conditioned latent in
`cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`, with an example command given below:
<br>
```
python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act --cond_pos --replay_batch --cclass
python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act \
--cond_pos --replay_batch --cclass
```
<br>
@ -249,5 +291,7 @@ python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps
* once models are trained, they can be sampled from jointly by running:
```
python ebm_combine.py --task=conceptcombine --exp_size=<exp_size> --exp_shape=<exp_shape> --exp_pos=<exp_pos> --exp_rot=<exp_rot> --resume_size=<resume_size> --resume_shape=<resume_shape> --resume_rot=<resume_rot> --resume_pos=<resume_pos>
python ebm_combine.py --task=conceptcombine --exp_size=<exp_size> --exp_shape=<exp_shape> --exp_pos=<exp_pos> \
--exp_rot=<exp_rot> --resume_size=<resume_size> --resume_shape=<resume_shape> --resume_rot=<resume_rot> \
--resume_pos=<resume_pos>
```


@ -1,40 +1,65 @@
import tensorflow as tf
import math
import os.path as osp
import numpy as np
import tensorflow as tf
from data import Cifar10, DSprites, Mnist
from hmc import hmc
from models import DspritesNet, MnistNet, ResNet32, ResNet32Large, ResNet32Wider
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
from data import Cifar10, Mnist, DSprites
from scipy.misc import logsumexp
from scipy.misc import imsave
from utils import optimistic_restore
import os.path as osp
import numpy as np
from tqdm import tqdm
from utils import optimistic_restore
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss')
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
flags.DEFINE_integer('batch_size', 16, 'Size of inputs')
flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from')
flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
flags.DEFINE_string(
"dataset", "cifar10", "cifar10 or mnist or dsprites or 2d or toy Gauss"
)
flags.DEFINE_string(
"logdir", "cachedir", "location where log of experiments will be stored"
)
flags.DEFINE_string("exp", "default", "name of experiments")
flags.DEFINE_integer(
"data_workers", 5, "Number of different data workers to load data in parallel"
)
flags.DEFINE_integer("batch_size", 16, "Size of inputs")
flags.DEFINE_string("resume_iter", "-1", "iteration to resume training from")
flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais')
flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian')
flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box')
flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model')
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not')
flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not')
flags.DEFINE_bool('large_model', False, 'Use large model to evaluate')
flags.DEFINE_bool('wider_model', False, 'Use wider model to evaluate')
flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps')
flags.DEFINE_bool(
"max_pool",
False,
"Whether or not to use max pooling rather than strided convolutions",
)
flags.DEFINE_integer(
"num_filters",
64,
"number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
)
flags.DEFINE_integer("pdist", 10, "number of intermediate distributions for ais")
flags.DEFINE_integer("gauss_dim", 500, "dimensions for modeling Gaussian")
flags.DEFINE_integer(
"rescale", 1, "factor to rescale input outside of normal (0, 1) box"
)
flags.DEFINE_float(
"temperature", 1, "temperature at which to compute likelihood of model"
)
flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
flags.DEFINE_bool(
"cclass",
False,
"Whether to evaluate the log likelihood of conditional model or not",
)
flags.DEFINE_bool(
"single",
False,
"Whether to evaluate the log likelihood of conditional model or not",
)
flags.DEFINE_bool("large_model", False, "Use large model to evaluate")
flags.DEFINE_bool("wider_model", False, "Use wider model to evaluate")
flags.DEFINE_float("alr", 0.0045, "Learning rate to use for HMC steps")
FLAGS = flags.FLAGS
@ -45,11 +70,12 @@ label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))
def unscale_im(im):
return (255 * np.clip(im, 0, 1)).astype(np.uint8)
def gauss_prob_log(x, prec=1.0):
nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec
prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2.0 * prec
return norm_constant_log + prob_density_log
@ -73,23 +99,36 @@ def model_prob_log(x, e_func, weights, temp):
def bridge_prob_neg_log(alpha, x, e_func, weights, temp):
if FLAGS.dataset == "gauss":
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature)
norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(
x, prec=FLAGS.temperature
)
else:
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp)
# Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely
norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * model_prob_log(
x, e_func, weights, temp
)
# Add an additional log likelihood penalty so that points outside of (0,
# 1) box are *highly* unlikely
if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1])
elif FLAGS.dataset == 'mnist':
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1, 2])
if FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
oob_prob = tf.reduce_sum(
tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1]
)
elif FLAGS.dataset == "mnist":
oob_prob = tf.reduce_sum(
tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1, 2]
)
else:
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis = [1, 2, 3])
oob_prob = tf.reduce_sum(
tf.square(100 * (x - tf.clip_by_value(x, 0.0, FLAGS.rescale))),
axis=[1, 2, 3],
)
return -norm_prob + oob_prob
def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10):
def ancestral_sample(
e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10
):
if FLAGS.dataset == "2d":
x = tf.placeholder(tf.float32, shape=(None, 2))
elif FLAGS.dataset == "gauss":
@ -130,41 +169,46 @@ def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_
def main():
# Initialize dataset
if FLAGS.dataset == 'cifar10':
if FLAGS.dataset == "cifar10":
dataset = Cifar10(train=False, rescale=FLAGS.rescale)
channel_num = 3
dim_input = 32 * 32 * 3
elif FLAGS.dataset == 'imagenet':
32 * 32 * 3
elif FLAGS.dataset == "imagenet":
dataset = ImagenetClass()
channel_num = 3
dim_input = 64 * 64 * 3
elif FLAGS.dataset == 'mnist':
64 * 64 * 3
elif FLAGS.dataset == "mnist":
dataset = Mnist(train=False, rescale=FLAGS.rescale)
channel_num = 1
dim_input = 28 * 28 * 1
elif FLAGS.dataset == 'dsprites':
28 * 28 * 1
elif FLAGS.dataset == "dsprites":
dataset = DSprites()
channel_num = 1
dim_input = 64 * 64 * 1
elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
64 * 64 * 1
elif FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
dataset = Box2D()
dim_output = 1
data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)
data_loader = DataLoader(
dataset,
batch_size=FLAGS.batch_size,
num_workers=FLAGS.data_workers,
drop_last=False,
shuffle=True,
)
if FLAGS.dataset == 'mnist':
if FLAGS.dataset == "mnist":
model = MnistNet(num_channels=channel_num)
elif FLAGS.dataset == 'cifar10':
elif FLAGS.dataset == "cifar10":
if FLAGS.large_model:
model = ResNet32Large(num_filters=128)
elif FLAGS.wider_model:
model = ResNet32Wider(num_filters=192)
else:
model = ResNet32(num_channels=channel_num, num_filters=128)
elif FLAGS.dataset == 'dsprites':
elif FLAGS.dataset == "dsprites":
model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)
weights = model.construct_weights('context_{}'.format(0))
weights = model.construct_weights("context_{}".format(0))
config = tf.ConfigProto()
sess = tf.Session(config=config)
@ -173,8 +217,8 @@ def main():
sess.run(tf.global_variables_initializer())
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
resume_itr = FLAGS.resume_iter
model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
FLAGS.resume_iter
if FLAGS.resume_iter != "-1":
optimistic_restore(sess, model_file)
@ -182,14 +226,17 @@ def main():
print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
# saver.restore(sess, model_file)
chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
model, weights, FLAGS.batch_size, temp=FLAGS.temperature
)
print("Finished constructing ancestral sample ...................")
if FLAGS.dataset != "gauss":
comb_weights_cum = []
batch_size = tf.shape(x_init)[0]
label_tiled = tf.tile(label_default, (batch_size, 1))
e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
e_compute = -FLAGS.temperature * model.forward(
x_init, weights, label=label_tiled
)
e_pos_list = []
for data_corrupt, data, label_gt in tqdm(data_loader):
@ -205,44 +252,75 @@ def main():
alr = 0.0085
elif FLAGS.dataset == "mnist":
alr = 0.0065
#90 alr = 0.0035
# 90 alr = 0.0035
else:
# alr = 0.0125
if FLAGS.rescale == 8:
alr = 0.0085
else:
alr = 0.0045
#
#
for i in range(1):
tot_weight = 0
for j in tqdm(range(1, FLAGS.pdist+1)):
for j in tqdm(range(1, FLAGS.pdist + 1)):
if j == 1:
if FLAGS.dataset == "cifar10":
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
x_curr = np.random.uniform(
0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3)
)
elif FLAGS.dataset == "gauss":
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
x_curr = np.random.uniform(
0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim)
)
elif FLAGS.dataset == "mnist":
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
x_curr = np.random.uniform(
0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28)
)
else:
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))
x_curr = np.random.uniform(
0, FLAGS.rescale, size=(FLAGS.batch_size, 2)
)
alpha_prev = (j-1) / FLAGS.pdist
alpha_prev = (j - 1) / FLAGS.pdist
alpha_new = j / FLAGS.pdist
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
cweight, x_curr = sess.run(
[chain_weights, x],
{
a_prev: alpha_prev,
a_new: alpha_new,
x_init: x_curr,
approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
},
)
tot_weight = tot_weight + cweight
print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))
print(
"Total values of lower value based off forward sampling",
np.mean(tot_weight),
np.std(tot_weight),
)
tot_weight = 0
for j in tqdm(range(FLAGS.pdist, 0, -1)):
alpha_new = (j-1) / FLAGS.pdist
alpha_new = (j - 1) / FLAGS.pdist
alpha_prev = j / FLAGS.pdist
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
cweight, x_curr = sess.run(
[chain_weights, x],
{
a_prev: alpha_prev,
a_new: alpha_new,
x_init: x_curr,
approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
},
)
tot_weight = tot_weight - cweight
print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))
print(
"Total values of upper value based off backward sampling",
np.mean(tot_weight),
np.std(tot_weight),
)
if __name__ == "__main__":


@ -14,223 +14,244 @@
# ==============================================================================
"""Adam for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.ops import (
control_flow_ops,
math_ops,
resource_variable_ops,
state_ops,
)
from tensorflow.python.training import optimizer, training_ops
from tensorflow.python.util.tf_export import tf_export
import tensorflow as tf
@tf_export("train.AdamOptimizer")
class AdamOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adam algorithm.
"""Optimizer that implements the Adam algorithm.
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
use_locking=False, name="Adam"):
"""Construct a new Adam optimizer.
Initialization:
$$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
$$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
$$t := 0 \text{(Initialize timestep)}$$
The update rule for `variable` with gradient `g` uses an optimization
described at the end of Section 2 of the paper:
$$t := t + 1$$
$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
The default value of 1e-8 for epsilon might not be a good default in
general. For example, when training an Inception network on ImageNet a
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
formulation just before Section 2.1 of the Kingma and Ba paper rather than
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
hat" in the paper.
The sparse implementation of this algorithm (used when the gradient is an
IndexedSlices object, typically because of `tf.gather` or an embedding
lookup in the forward pass) does apply momentum to variable slices even if
they were not used in the forward pass (meaning they have a gradient equal
to zero). Momentum decay (beta1) is also applied to the entire momentum
accumulator. This means that the sparse behavior is equivalent to the dense
behavior (in contrast to some momentum implementations which ignore momentum
unless a variable slice was actually used).
Args:
learning_rate: A Tensor or a floating point value. The learning rate.
beta1: A float value or a constant float tensor.
The exponential decay rate for the 1st moment estimates.
beta2: A float value or a constant float tensor.
The exponential decay rate for the 2nd moment estimates.
epsilon: A small constant for numerical stability. This epsilon is
"epsilon hat" in the Kingma and Ba paper (in the formula just before
Section 2.1), not the epsilon in Algorithm 1 of the paper.
use_locking: If True use locks for update operations.
name: Optional name for the operations created when applying gradients.
Defaults to "Adam".
@compatibility(eager)
When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
`epsilon` can each be a callable that takes no arguments and returns the
actual value to use. This can be useful for changing these values across
different invocations of optimizer functions.
@end_compatibility
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
"""
super(AdamOptimizer, self).__init__(use_locking, name)
self._lr = learning_rate
self._beta1 = beta1
self._beta2 = beta2
self._epsilon = epsilon
# Tensor versions of the constructor arguments, created in _prepare().
self._lr_t = None
self._beta1_t = None
self._beta2_t = None
self._epsilon_t = None
def __init__(
self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
use_locking=False,
name="Adam",
):
"""Construct a new Adam optimizer.
# Created in SparseApply if needed.
self._updated_lr = None
Initialization:
def _get_beta_accumulators(self):
with ops.init_scope():
if context.executing_eagerly():
graph = None
else:
graph = ops.get_default_graph()
return (self._get_non_slot_variable("beta1_power", graph=graph),
self._get_non_slot_variable("beta2_power", graph=graph))
$$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
$$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
$$t := 0 \text{(Initialize timestep)}$$
def _create_slots(self, var_list):
# Create the beta1 and beta2 accumulators on the same device as the first
# variable. Sort the var_list to make sure this device is consistent across
# workers (these need to go on the same PS, otherwise some updates are
# silently ignored).
first_var = min(var_list, key=lambda x: x.name)
self._create_non_slot_variable(initial_value=self._beta1,
name="beta1_power",
colocate_with=first_var)
self._create_non_slot_variable(initial_value=self._beta2,
name="beta2_power",
colocate_with=first_var)
The update rule for `variable` with gradient `g` uses an optimization
described at the end of section2 of the paper:
# Create slots for the first and second moments.
for v in var_list:
self._zeros_slot(v, "m", self._name)
self._zeros_slot(v, "v", self._name)
$$t := t + 1$$
$$lr_t := \text{learning\\_rate} * \\sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
def _prepare(self):
lr = self._call_if_callable(self._lr)
beta1 = self._call_if_callable(self._beta1)
beta2 = self._call_if_callable(self._beta2)
epsilon = self._call_if_callable(self._epsilon)
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
$$variable := variable - lr_t * m_t / (\\sqrt{v_t} + \\epsilon)$$
self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
The default value of 1e-8 for epsilon might not be a good default in
general. For example, when training an Inception network on ImageNet a
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
formulation just before Section 2.1 of the Kingma and Ba paper rather than
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
hat" in the paper.
def _apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta1_power, beta2_power = self._get_beta_accumulators()
The sparse implementation of this algorithm (used when the gradient is an
IndexedSlices object, typically because of `tf.gather` or an embedding
lookup in the forward pass) does apply momentum to variable slices even if
they were not used in the forward pass (meaning they have a gradient equal
to zero). Momentum decay (beta1) is also applied to the entire momentum
accumulator. This means that the sparse behavior is equivalent to the dense
behavior (in contrast to some momentum implementations which ignore momentum
unless a variable slice was actually used).
clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
# Clip gradients by 3 std
return training_ops.apply_adam(
var, m, v,
math_ops.cast(beta1_power, var.dtype.base_dtype),
math_ops.cast(beta2_power, var.dtype.base_dtype),
math_ops.cast(self._lr_t, var.dtype.base_dtype),
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
grad, use_locking=self._use_locking).op
Args:
learning_rate: A Tensor or a floating point value. The learning rate.
beta1: A float value or a constant float tensor.
The exponential decay rate for the 1st moment estimates.
beta2: A float value or a constant float tensor.
The exponential decay rate for the 2nd moment estimates.
epsilon: A small constant for numerical stability. This epsilon is
"epsilon hat" in the Kingma and Ba paper (in the formula just before
Section 2.1), not the epsilon in Algorithm 1 of the paper.
use_locking: If True use locks for update operations.
name: Optional name for the operations created when applying gradients.
Defaults to "Adam".
def _resource_apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta1_power, beta2_power = self._get_beta_accumulators()
return training_ops.resource_apply_adam(
var.handle, m.handle, v.handle,
math_ops.cast(beta1_power, grad.dtype.base_dtype),
math_ops.cast(beta2_power, grad.dtype.base_dtype),
math_ops.cast(self._lr_t, grad.dtype.base_dtype),
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
grad, use_locking=self._use_locking)
@compatibility(eager)
When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
`epsilon` can each be a callable that takes no arguments and returns the
actual value to use. This can be useful for changing these values across
different invocations of optimizer functions.
@end_compatibility
"""
super(AdamOptimizer, self).__init__(use_locking, name)
self._lr = learning_rate
self._beta1 = beta1
self._beta2 = beta2
self._epsilon = epsilon
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
beta1_power, beta2_power = self._get_beta_accumulators()
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, "m")
m_scaled_g_values = grad * (1 - beta1_t)
m_t = state_ops.assign(m, m * beta1_t,
use_locking=self._use_locking)
with ops.control_dependencies([m_t]):
m_t = scatter_add(m, indices, m_scaled_g_values)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, "v")
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
with ops.control_dependencies([v_t]):
v_t = scatter_add(v, indices, v_scaled_g_values)
v_sqrt = math_ops.sqrt(v_t)
var_update = state_ops.assign_sub(var,
lr * m_t / (v_sqrt + epsilon_t),
use_locking=self._use_locking)
return control_flow_ops.group(*[var_update, m_t, v_t])
# Tensor versions of the constructor arguments, created in _prepare().
self._lr_t = None
self._beta1_t = None
self._beta2_t = None
self._epsilon_t = None
def _apply_sparse(self, grad, var):
return self._apply_sparse_shared(
grad.values, var, grad.indices,
lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
x, i, v, use_locking=self._use_locking))
# Created in SparseApply if needed.
self._updated_lr = None
def _resource_scatter_add(self, x, i, v):
with ops.control_dependencies(
[resource_variable_ops.resource_scatter_add(
x.handle, i, v)]):
return x.value()
def _get_beta_accumulators(self):
with ops.init_scope():
if context.executing_eagerly():
graph = None
else:
graph = ops.get_default_graph()
return (
self._get_non_slot_variable("beta1_power", graph=graph),
self._get_non_slot_variable("beta2_power", graph=graph),
)
def _resource_apply_sparse(self, grad, var, indices):
return self._apply_sparse_shared(
grad, var, indices, self._resource_scatter_add)
def _create_slots(self, var_list):
# Create the beta1 and beta2 accumulators on the same device as the first
# variable. Sort the var_list to make sure this device is consistent across
# workers (these need to go on the same PS, otherwise some updates are
# silently ignored).
first_var = min(var_list, key=lambda x: x.name)
self._create_non_slot_variable(
initial_value=self._beta1, name="beta1_power", colocate_with=first_var
)
self._create_non_slot_variable(
initial_value=self._beta2, name="beta2_power", colocate_with=first_var
)
def _finish(self, update_ops, name_scope):
# Update the power accumulators.
with ops.control_dependencies(update_ops):
beta1_power, beta2_power = self._get_beta_accumulators()
with ops.colocate_with(beta1_power):
update_beta1 = beta1_power.assign(
beta1_power * self._beta1_t, use_locking=self._use_locking)
update_beta2 = beta2_power.assign(
beta2_power * self._beta2_t, use_locking=self._use_locking)
return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
name=name_scope)
# Create slots for the first and second moments.
for v in var_list:
self._zeros_slot(v, "m", self._name)
self._zeros_slot(v, "v", self._name)
def _prepare(self):
lr = self._call_if_callable(self._lr)
beta1 = self._call_if_callable(self._beta1)
beta2 = self._call_if_callable(self._beta2)
epsilon = self._call_if_callable(self._epsilon)
self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
def _apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta1_power, beta2_power = self._get_beta_accumulators()
clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
# Clip gradients by 3 std
return training_ops.apply_adam(
var,
m,
v,
math_ops.cast(beta1_power, var.dtype.base_dtype),
math_ops.cast(beta2_power, var.dtype.base_dtype),
math_ops.cast(self._lr_t, var.dtype.base_dtype),
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
grad,
use_locking=self._use_locking,
).op
def _resource_apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta1_power, beta2_power = self._get_beta_accumulators()
return training_ops.resource_apply_adam(
var.handle,
m.handle,
v.handle,
math_ops.cast(beta1_power, grad.dtype.base_dtype),
math_ops.cast(beta2_power, grad.dtype.base_dtype),
math_ops.cast(self._lr_t, grad.dtype.base_dtype),
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
grad,
use_locking=self._use_locking,
)
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
beta1_power, beta2_power = self._get_beta_accumulators()
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
lr = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, "m")
m_scaled_g_values = grad * (1 - beta1_t)
m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
with ops.control_dependencies([m_t]):
m_t = scatter_add(m, indices, m_scaled_g_values)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, "v")
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
with ops.control_dependencies([v_t]):
v_t = scatter_add(v, indices, v_scaled_g_values)
v_sqrt = math_ops.sqrt(v_t)
var_update = state_ops.assign_sub(
var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking
)
return control_flow_ops.group(*[var_update, m_t, v_t])
def _apply_sparse(self, grad, var):
return self._apply_sparse_shared(
grad.values,
var,
grad.indices,
lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
x, i, v, use_locking=self._use_locking
),
)
def _resource_scatter_add(self, x, i, v):
with ops.control_dependencies(
[resource_variable_ops.resource_scatter_add(x.handle, i, v)]
):
return x.value()
def _resource_apply_sparse(self, grad, var, indices):
return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add)
def _finish(self, update_ops, name_scope):
# Update the power accumulators.
with ops.control_dependencies(update_ops):
beta1_power, beta2_power = self._get_beta_accumulators()
with ops.colocate_with(beta1_power):
update_beta1 = beta1_power.assign(
beta1_power * self._beta1_t, use_locking=self._use_locking
)
update_beta2 = beta2_power.assign(
beta2_power * self._beta2_t, use_locking=self._use_locking
)
return control_flow_ops.group(
*update_ops + [update_beta1, update_beta2], name=name_scope
)
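
Reviewer note on the dense path above: `_apply_dense` clips each gradient element to three times the bias-corrected root of the second-moment slot, plus a 0.1 floor, and then hands the clipped gradient to the stock Adam kernel. The following is a minimal NumPy sketch of one such step, not the TensorFlow implementation itself. The function name `clipped_adam_step` and the explicit step counter `t` are illustrative assumptions; the sketch also assumes the power accumulators equal `beta ** t`, which matches how `_create_slots` initializes them and `_finish` multiplies them once per step.

```python
import numpy as np


def clipped_adam_step(var, grad, m, v, t,
                      lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One dense Adam step with the 3-std gradient clip (illustrative sketch).

    `t` is the 1-based step index (an assumption of this sketch); the
    original code keeps beta1**t and beta2**t in non-slot variables.
    """
    beta1_power = beta1 ** t
    beta2_power = beta2 ** t

    # Clip using the *previous* second moment, bias-corrected, plus a floor.
    clip_bounds = 3.0 * np.sqrt(v / (1.0 - beta2_power)) + 0.1
    grad = np.clip(grad, -clip_bounds, clip_bounds)

    # Standard Adam moment updates on the clipped gradient.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad

    # "Epsilon hat" formulation: fold the bias correction into the step size.
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    var = var - lr_t * m / (np.sqrt(v) + eps)
    return var, m, v
```

On the very first step the second-moment slot is still zero, so the bound collapses to the 0.1 floor and any larger gradient element is cut down to 0.1 in magnitude before it reaches the moment estimates.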


@ -1,42 +1,48 @@
from tensorflow.python.platform import flags
from tensorflow.contrib.data.python.ops import batching, threadpool
import tensorflow as tf
import json
from torch.utils.data import Dataset
import pickle
import os.path as osp
import os
import numpy as np
import os.path as osp
import pickle
import time
from scipy.misc import imread, imresize
from skimage.color import rgb2grey
from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
from torchvision import transforms
from imagenet_preprocessing import ImagenetPreprocessor
import numpy as np
import tensorflow as tf
import torch
import torchvision
from imagenet_preprocessing import ImagenetPreprocessor
from scipy.misc import imread, imresize
from tensorflow.contrib.data.python.ops import batching
from tensorflow.python.platform import flags
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN, ImageFolder
FLAGS = flags.FLAGS
ROOT_DIR = "./results"
# Dataset Options
flags.DEFINE_string('dsprites_path',
'/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
'path to dsprites characters')
flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'directory containing imagenet record files and index.json')
flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
flags.DEFINE_bool('dsize_only', False,'fix all factors except for size of objects')
flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
flags.DEFINE_bool('dsprites_restrict', False, 'fix all factors except for rotation of objects')
flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')
flags.DEFINE_string(
"dsprites_path",
"/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
"path to dsprites characters",
)
flags.DEFINE_string(
"imagenet_datadir", "/root/imagenet_big", "directory containing imagenet record files and index.json"
)
flags.DEFINE_bool("dshape_only", False, "fix all factors except for shapes")
flags.DEFINE_bool("dpos_only", False, "fix all factors except for positions of shapes")
flags.DEFINE_bool("dsize_only", False, "fix all factors except for size of objects")
flags.DEFINE_bool("drot_only", False, "fix all factors except for rotation of objects")
flags.DEFINE_bool(
"dsprites_restrict", False, "fix all factors except for rotation of objects"
)
flags.DEFINE_string("imagenet_path", "/root/imagenet", "path to imagenet images")
# Data augmentation options
flags.DEFINE_bool('cutout_inside', False,'whether cutout should always be inside the image')
flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout')
flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout')
flags.DEFINE_bool('cutout', False,'whether to add cutout regularizer to data')
flags.DEFINE_bool("cutout_inside", False, "whether cutout should always be inside the image")
flags.DEFINE_float("cutout_prob", 1.0, "probability of using cutout")
flags.DEFINE_integer("cutout_mask_size", 16, "size of cutout")
flags.DEFINE_bool("cutout", False, "whether to add cutout regularizer to data")
def cutout(mask_color=(0, 0, 0)):
@ -91,13 +97,15 @@ class TFImagenetLoader(Dataset):
self.curr_sample = 0
index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
index_path = osp.join(FLAGS.imagenet_datadir, "index.json")
with open(index_path) as f:
metadata = json.load(f)
counts = metadata['record_counts']
counts = metadata["record_counts"]
if split == 'train':
file_names = list(sorted([x for x in counts.keys() if x.startswith('train')]))
if split == "train":
file_names = list(
sorted([x for x in counts.keys() if x.startswith("train")])
)
result_records_to_skip = None
files = []
@ -111,30 +119,44 @@ class TFImagenetLoader(Dataset):
# Record the number to skip in the first file
result_records_to_skip = records_to_skip
files.append(filename)
records_to_read -= (records_in_file - records_to_skip)
records_to_read -= records_in_file - records_to_skip
records_to_skip = 0
else:
break
else:
files = list(sorted([x for x in counts.keys() if x.startswith('validation')]))
files = list(
sorted([x for x in counts.keys() if x.startswith("validation")])
)
files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess
preprocess_function = ImagenetPreprocessor(
128, dtype=tf.float32, train=False
).parse_and_preprocess
ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
ds = tf.data.TFRecordDataset.from_generator(
lambda: files, output_types=tf.string
)
ds = ds.apply(tf.data.TFRecordDataset)
ds = ds.take(im_length)
ds = ds.prefetch(buffer_size=FLAGS.batch_size)
ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4))
ds = ds.apply(
batching.map_and_batch(
map_func=preprocess_function,
batch_size=FLAGS.batch_size,
num_parallel_batches=4,
)
)
ds = ds.prefetch(buffer_size=2)
ds_iterator = ds.make_initializable_iterator()
labels, images = ds_iterator.get_next()
self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
self.images = tf.clip_by_value(
images / 256 + tf.random_uniform(tf.shape(images), 0, 1.0 / 256), 0.0, 1.0
)
self.labels = labels
config = tf.ConfigProto(device_count = {'GPU': 0})
config = tf.ConfigProto(device_count={"GPU": 0})
sess = tf.Session(config=config)
sess.run(ds_iterator.initializer)
@ -147,11 +169,17 @@ class TFImagenetLoader(Dataset):
sess = self.sess
im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
im_corrupt = np.random.uniform(
0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3)
)
label, im = sess.run([self.labels, self.images])
im = im * self.rescale
label = np.eye(1000)[label.squeeze() - 1]
im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label)
im, im_corrupt, label = (
torch.from_numpy(im),
torch.from_numpy(im_corrupt),
torch.from_numpy(label),
)
return im_corrupt, im, label
def __iter__(self):
@ -160,6 +188,7 @@ class TFImagenetLoader(Dataset):
def __len__(self):
return self.im_length
class CelebA(Dataset):
def __init__(self):
@ -180,25 +209,18 @@ class CelebA(Dataset):
im = imread(path)
im = imresize(im, (32, 32))
image_size = 32
im = im / 255.
im = im / 255.0
if FLAGS.datasource == 'default':
if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random':
im_corrupt = np.random.uniform(
0, 1, size=(image_size, image_size, 3))
elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform(0, 1, size=(image_size, image_size, 3))
return im_corrupt, im, label
class Cifar10(Dataset):
def __init__(
self,
train=True,
full=False,
augment=False,
noise=True,
rescale=1.0):
def __init__(self, train=True, full=False, augment=False, noise=True, rescale=1.0):
if augment:
transform_list = [
@ -215,16 +237,10 @@ class Cifar10(Dataset):
transform = transforms.ToTensor()
self.full = full
self.data = CIFAR10(
ROOT_DIR,
transform=transform,
train=train,
download=True)
self.data = CIFAR10(ROOT_DIR, transform=transform, train=train, download=True)
self.test_data = CIFAR10(
ROOT_DIR,
transform=transform,
train=False,
download=True)
ROOT_DIR, transform=transform, train=False, download=True
)
self.one_hot_map = np.eye(10)
self.noise = noise
self.rescale = rescale
@ -255,16 +271,18 @@ class Cifar10(Dataset):
im = im * 255 / 256
if self.noise:
im = im * self.rescale + \
np.random.uniform(0, self.rescale * 1 / 256., im.shape)
im = im * self.rescale + np.random.uniform(
0, self.rescale * 1 / 256.0, im.shape
)
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default':
if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random':
elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform(
0.0, self.rescale, (image_size, image_size, 3))
0.0, self.rescale, (image_size, image_size, 3)
)
return im_corrupt, im, label
@ -287,10 +305,8 @@ class Cifar100(Dataset):
transform = transforms.ToTensor()
self.data = CIFAR100(
"/root/cifar100",
transform=transform,
train=train,
download=True)
"/root/cifar100", transform=transform, train=train, download=True
)
self.one_hot_map = np.eye(100)
def __len__(self):
@ -308,11 +324,10 @@ class Cifar100(Dataset):
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default':
if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random':
im_corrupt = np.random.uniform(
0.0, 1.0, (image_size, image_size, 3))
elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label
@ -340,11 +355,10 @@ class Svhn(Dataset):
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default':
if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random':
im_corrupt = np.random.uniform(
0.0, 1.0, (image_size, image_size, 3))
elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label
@ -352,9 +366,8 @@ class Svhn(Dataset):
class Mnist(Dataset):
def __init__(self, train=True, rescale=1.0):
self.data = MNIST(
"/root/mnist",
transform=transforms.ToTensor(),
download=True, train=train)
"/root/mnist", transform=transforms.ToTensor(), download=True, train=train
)
self.labels = np.eye(10)
self.rescale = rescale
@ -367,13 +380,13 @@ class Mnist(Dataset):
im = im.squeeze()
# im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28))
# im = im.numpy() / 2 + 0.2
im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28))
im = im.numpy() / 256 * 255 + np.random.uniform(0, 1.0 / 256, (28, 28))
im = im * self.rescale
image_size = 28
if FLAGS.datasource == 'default':
if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
elif FLAGS.datasource == 'random':
elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform(0, self.rescale, (28, 28))
return im_corrupt, im, label
@ -381,54 +394,63 @@ class Mnist(Dataset):
class DSprites(Dataset):
def __init__(
self,
cond_size=False,
cond_shape=False,
cond_pos=False,
cond_rot=False):
self, cond_size=False, cond_shape=False, cond_pos=False, cond_rot=False
):
dat = np.load(FLAGS.dsprites_path)
if FLAGS.dshape_only:
l = dat['latents_values']
mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 /
31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
self.label = np.tile(dat['latents_values'][mask], (10000, 1))
l = dat["latents_values"]
mask = (
(l[:, 4] == 16 / 31)
& (l[:, 5] == 16 / 31)
& (l[:, 2] == 0.5)
& (l[:, 3] == 30 * np.pi / 39)
)
self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
self.label = np.tile(dat["latents_values"][mask], (10000, 1))
self.label = self.label[:, 1:2]
elif FLAGS.dpos_only:
l = dat['latents_values']
l = dat["latents_values"]
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
mask = (l[:, 1] == 1) & (
l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
self.label = np.tile(dat['latents_values'][mask], (100, 1))
mask = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
self.label = np.tile(dat["latents_values"][mask], (100, 1))
self.label = self.label[:, 4:] + 0.5
elif FLAGS.dsize_only:
l = dat['latents_values']
l = dat["latents_values"]
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 /
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
self.label = np.tile(dat['latents_values'][mask], (10000, 1))
self.label = (self.label[:, 2:3])
mask = (
(l[:, 3] == 30 * np.pi / 39)
& (l[:, 4] == 16 / 31)
& (l[:, 5] == 16 / 31)
& (l[:, 1] == 1)
)
self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
self.label = np.tile(dat["latents_values"][mask], (10000, 1))
self.label = self.label[:, 2:3]
elif FLAGS.drot_only:
l = dat['latents_values']
mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 /
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
self.label = np.tile(dat['latents_values'][mask], (100, 1))
self.label = (self.label[:, 3:4])
l = dat["latents_values"]
mask = (
(l[:, 2] == 0.5)
& (l[:, 4] == 16 / 31)
& (l[:, 5] == 16 / 31)
& (l[:, 1] == 1)
)
self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
self.label = np.tile(dat["latents_values"][mask], (100, 1))
self.label = self.label[:, 3:4]
self.label = np.concatenate(
[np.cos(self.label), np.sin(self.label)], axis=1)
[np.cos(self.label), np.sin(self.label)], axis=1
)
elif FLAGS.dsprites_restrict:
l = dat['latents_values']
l = dat["latents_values"]
mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39)
self.data = dat['imgs'][mask]
self.label = dat['latents_values'][mask]
self.data = dat["imgs"][mask]
self.label = dat["latents_values"][mask]
else:
self.data = dat['imgs']
self.label = dat['latents_values']
self.data = dat["imgs"]
self.label = dat["latents_values"]
if cond_size:
self.label = self.label[:, 2:3]
@ -439,7 +461,8 @@ class DSprites(Dataset):
elif cond_rot:
self.label = self.label[:, 3:4]
self.label = np.concatenate(
[np.cos(self.label), np.sin(self.label)], axis=1)
[np.cos(self.label), np.sin(self.label)], axis=1
)
else:
self.label = self.label[:, 1:2]
@ -452,20 +475,20 @@ class DSprites(Dataset):
im = self.data[index]
image_size = 64
if not (
FLAGS.dpos_only or FLAGS.dsize_only) and (
not FLAGS.cond_size) and (
not FLAGS.cond_pos) and (
not FLAGS.cond_rot) and (
not FLAGS.drot_only):
label = self.identity[self.label[index].astype(
np.int32) - 1].squeeze()
if (
not (FLAGS.dpos_only or FLAGS.dsize_only)
and (not FLAGS.cond_size)
and (not FLAGS.cond_pos)
and (not FLAGS.cond_rot)
and (not FLAGS.drot_only)
):
label = self.identity[self.label[index].astype(np.int32) - 1].squeeze()
else:
label = self.label[index]
if FLAGS.datasource == 'default':
if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
elif FLAGS.datasource == 'random':
elif FLAGS.datasource == "random":
im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)
return im_corrupt, im, label
@ -478,25 +501,20 @@ class Imagenet(Dataset):
for i in range(1, 11):
f = pickle.load(
open(
osp.join(
FLAGS.imagenet_path,
'train_data_batch_{}'.format(i)),
'rb'))
osp.join(FLAGS.imagenet_path, "train_data_batch_{}".format(i)),
"rb",
)
)
if i == 1:
labels = f['labels']
data = f['data']
labels = f["labels"]
data = f["data"]
else:
labels.extend(f['labels'])
data = np.vstack((data, f['data']))
labels.extend(f["labels"])
data = np.vstack((data, f["data"]))
else:
f = pickle.load(
open(
osp.join(
FLAGS.imagenet_path,
'val_data'),
'rb'))
labels = f['labels']
data = f['data']
f = pickle.load(open(osp.join(FLAGS.imagenet_path, "val_data"), "rb"))
labels = f["labels"]
data = f["data"]
self.labels = labels
self.data = data
@ -520,11 +538,10 @@ class Imagenet(Dataset):
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default':
if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random':
im_corrupt = np.random.uniform(
0.0, 1.0, (image_size, image_size, 3))
elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label
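
Reviewer note on the dataset classes above: each `__getitem__` returns a clean image plus a "corrupted" starting point chosen by `FLAGS.datasource`, either the clean image with Gaussian noise of scale 0.3 (`default`) or pure uniform noise (`random`). The sketch below isolates that logic in NumPy, following the `Cifar10` variant with its `rescale` factor and uniform dequantization noise. The function name `make_corrupt_pair`, the explicit `rng` argument, and the `ValueError` for an unknown mode are assumptions made for illustration; the original code reads a global flag, seeds the global NumPy state, and has no branch for other values.

```python
import numpy as np


def make_corrupt_pair(im, datasource="default", rescale=1.0, rng=None):
    """Return (im_corrupt, im) for an image with values in [0, 1] (sketch)."""
    rng = np.random.default_rng() if rng is None else rng

    # Map [0, 1] to [0, 255/256], then add uniform dequantization noise
    # of width 1/256 so the result stays within [0, rescale].
    im = im * 255 / 256
    im = im * rescale + rng.uniform(0, rescale / 256.0, im.shape)

    if datasource == "default":
        # Start sampling from a noisy copy of the real image.
        im_corrupt = im + 0.3 * rng.standard_normal(im.shape)
    elif datasource == "random":
        # Start sampling from uniform noise over the data range.
        im_corrupt = rng.uniform(0.0, rescale, im.shape)
    else:
        # Assumption of this sketch: fail loudly on an unknown mode.
        raise ValueError("unknown datasource: {}".format(datasource))

    return im_corrupt, im
```

Passing the generator in explicitly keeps the sketch reproducible in tests; the classes in the diff instead reseed NumPy from the sample index and the wall clock on every call.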


@ -1,68 +1,100 @@
import os
import os.path as osp
import numpy as np
import tensorflow as tf
import math
from tqdm import tqdm
from hmc import hmc
from custom_adam import AdamOptimizer
from models import DspritesNet
from scipy.misc import imsave
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader, Dataset
from models import DspritesNet
from utils import optimistic_restore, ReplayBuffer
import os.path as osp
import numpy as np
from rl_algs.logger import TensorBoardOutputFormat
from scipy.misc import imsave
import os
from custom_adam import AdamOptimizer
from tqdm import tqdm
from utils import ReplayBuffer
flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
flags.DEFINE_integer('data_workers', 4, 'Number of workers to do things')
flags.DEFINE_string('logdir', 'cachedir', 'directory for logging')
flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
flags.DEFINE_float('step_lr', 500, 'size of gradient descent size')
flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters')
flags.DEFINE_bool('cclass', True, 'not cclass')
flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results')
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.')
flags.DEFINE_bool('joint_shape', False, 'whether to use pos_size or pos_shape')
flags.DEFINE_bool('joint_rot', False, 'whether to use pos_size or pos_shape')
flags.DEFINE_integer("batch_size", 256, "Size of inputs")
flags.DEFINE_integer("data_workers", 4, "Number of workers to do things")
flags.DEFINE_string("logdir", "cachedir", "directory for logging")
flags.DEFINE_string(
"savedir", "cachedir", "location where log of experiments will be stored"
)
flags.DEFINE_integer(
"num_filters",
64,
"number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
)
flags.DEFINE_float("step_lr", 500, "size of gradient descent size")
flags.DEFINE_string(
"dsprites_path",
"/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
"path to dsprites characters",
)
flags.DEFINE_bool("cclass", True, "not cclass")
flags.DEFINE_bool("proj_cclass", False, "use for backwards compatibility reasons")
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
flags.DEFINE_bool("plot_curve", False, "Generate a curve of results")
flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
flags.DEFINE_string(
"task",
"conceptcombine",
"conceptcombine, labeldiscover, gentest, genbaseline, etc.",
)
flags.DEFINE_bool("joint_shape", False, "whether to use pos_size or pos_shape")
flags.DEFINE_bool("joint_rot", False, "whether to use pos_size or pos_shape")
# Conditions on which models to use
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale')
flags.DEFINE_bool("cond_pos", True, "whether to condition on position")
flags.DEFINE_bool("cond_rot", True, "whether to condition on rotation")
flags.DEFINE_bool("cond_shape", True, "whether to condition on shape")
flags.DEFINE_bool("cond_scale", True, "whether to condition on scale")
flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of experiments')
flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of experiments')
flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of experiments')
flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of experiments')
flags.DEFINE_integer('resume_size', 169000, 'Checkpoint iteration to resume the size model from')
flags.DEFINE_integer('resume_shape', 477000, 'Checkpoint iteration to resume the shape model from')
flags.DEFINE_integer('resume_pos', 8000, 'Checkpoint iteration to resume the position model from')
flags.DEFINE_integer('resume_rot', 690000, 'Checkpoint iteration to resume the rotation model from')
flags.DEFINE_integer('break_steps', 300, 'steps to break')
flags.DEFINE_string("exp_size", "dsprites_2018_cond_size", "name of experiments")
flags.DEFINE_string("exp_shape", "dsprites_2018_cond_shape", "name of experiments")
flags.DEFINE_string("exp_pos", "dsprites_2018_cond_pos_cert", "name of experiments")
flags.DEFINE_string("exp_rot", "dsprites_cond_rot_119_00", "name of experiments")
flags.DEFINE_integer("resume_size", 169000, "Checkpoint iteration to resume the size model from")
flags.DEFINE_integer("resume_shape", 477000, "Checkpoint iteration to resume the shape model from")
flags.DEFINE_integer("resume_pos", 8000, "Checkpoint iteration to resume the position model from")
flags.DEFINE_integer("resume_rot", 690000, "Checkpoint iteration to resume the rotation model from")
flags.DEFINE_integer("break_steps", 300, "steps to break")
# Whether to train for gentest
flags.DEFINE_bool('train', False, 'whether to train on generalization into multiple different predictions')
flags.DEFINE_bool(
"train",
False,
"whether to train on generalization into multiple different predictions",
)
FLAGS = flags.FLAGS
class DSpritesGen(Dataset):
def __init__(self, data, latents, frac=0.0):
l = latents
if FLAGS.joint_shape:
mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
mask_size = (
(l[:, 3] == 30 * np.pi / 39)
& (l[:, 4] == 16 / 31)
& (l[:, 5] == 16 / 31)
& (l[:, 2] == 0.5)
)
elif FLAGS.joint_rot:
mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
mask_size = (
(l[:, 1] == 1)
& (l[:, 4] == 16 / 31)
& (l[:, 5] == 16 / 31)
& (l[:, 2] == 0.5)
)
else:
mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1)
mask_size = (
(l[:, 3] == 30 * np.pi / 39)
& (l[:, 4] == 16 / 31)
& (l[:, 5] == 16 / 31)
& (l[:, 1] == 1)
)
mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
@ -80,12 +112,14 @@ class DSpritesGen(Dataset):
self.data = np.concatenate((data_pos, data_size), axis=0)
self.label = np.concatenate((l_pos, l_size), axis=0)
mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39))
mask_neg = (~(mask_size & mask_pos)) & (
(l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39)
)
data_add = data[mask_neg]
l_add = l[mask_neg]
perm_idx = np.random.permutation(data_add.shape[0])
select_idx = perm_idx[:int(frac*perm_idx.shape[0])]
select_idx = perm_idx[: int(frac * perm_idx.shape[0])]
data_add = data_add[select_idx]
l_add = l_add[select_idx]
@ -104,7 +138,9 @@ class DSpritesGen(Dataset):
if FLAGS.joint_shape:
label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
elif FLAGS.joint_rot:
label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])])
label_size = np.array(
[np.cos(self.label[index, 3]), np.sin(self.label[index, 3])]
)
else:
label_size = self.label[index, 2:3]
@ -114,14 +150,16 @@ class DSpritesGen(Dataset):
def labeldiscover(sess, kvs, data, latents, save_exp_dir):
LABEL_SIZE = kvs['LABEL_SIZE']
model_size = kvs['model_size']
weight_size = kvs['weight_size']
x_mod = kvs['X_NOISE']
LABEL_SIZE = kvs["LABEL_SIZE"]
model_size = kvs["model_size"]
weight_size = kvs["weight_size"]
x_mod = kvs["X_NOISE"]
label_output = LABEL_SIZE
for i in range(FLAGS.num_steps):
label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03)
label_output = label_output + tf.random_normal(
tf.shape(label_output), mean=0.0, stddev=0.03
)
e_noise = model_size.forward(x_mod, weight_size, label=label_output)
label_grad = tf.gradients(e_noise, [label_output])[0]
# label_grad = tf.Print(label_grad, [label_grad])
@ -130,13 +168,13 @@ def labeldiscover(sess, kvs, data, latents, save_exp_dir):
diffs = []
for i in range(30):
s = i*FLAGS.batch_size
d = (i+1)*FLAGS.batch_size
s = i * FLAGS.batch_size
d = (i + 1) * FLAGS.batch_size
data_i = data[s:d]
latent_i = latents[s:d]
latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))
feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init}
feed_dict = {x_mod: data_i, LABEL_SIZE: latent_init}
size_pred = sess.run([label_output], feed_dict)[0]
size_gt = latent_i[:, 2:3]
@ -155,9 +193,11 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))
weights_baseline = model_baseline.construct_weights(
"context_baseline_{}".format(frac)
)
X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32)
X_feed = tf.placeholder(shape=(None, 2 * FLAGS.num_filters), dtype=tf.float32)
X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
@ -168,14 +208,20 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
gvs = [(k, v) for (k, v) in gvs if k is not None]
train_op = optimizer.apply_gradients(gvs)
dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
dataloader = DataLoader(
DSpritesGen(data, latents, frac=frac),
batch_size=FLAGS.batch_size,
num_workers=6,
drop_last=True,
shuffle=True,
)
datafull = data
itr = 0
saver = tf.train.Saver()
vs = optimizer.variables()
optimizer.variables()
sess.run(tf.global_variables_initializer())
if FLAGS.train:
@ -185,7 +231,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
data_corrupt = data_corrupt.numpy()
label_size, label_pos = label_size.numpy(), label_pos.numpy()
data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters)
data_corrupt = np.random.randn(
data_corrupt.shape[0], 2 * FLAGS.num_filters
)
label_comb = np.concatenate([label_size, label_pos], axis=1)
feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}
@ -196,23 +244,27 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
itr += 1
saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))
saver.save(sess, osp.join(save_exp_dir, "model_genbaseline"))
saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))
saver.restore(sess, osp.join(save_exp_dir, "model_genbaseline"))
l = latents
if FLAGS.joint_shape:
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
else:
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
)
data_gen = datafull[mask_gen]
latents_gen = latents[mask_gen]
losses = []
for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters)
for dat, latent in zip(
np.array_split(data_gen, 10), np.array_split(latents_gen, 10)
):
data_init = np.random.randn(dat.shape[0], 2 * FLAGS.num_filters)
if FLAGS.joint_shape:
latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
@ -220,16 +272,19 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
latent = np.concatenate([latent_size, latent_pos], axis=1)
feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
else:
feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat}
feed_dict = {X_feed: data_init, LABEL: latent[:, [2, 4, 5]], X_label: dat}
loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
# print(loss)
losses.append(loss)
print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))
print(
"Overall MSE for generalization of {} for fraction of {}".format(
np.mean(losses), frac
)
)
data_try = data_gen[:10]
data_init = np.random.randn(10, 2*FLAGS.num_filters)
data_init = np.random.randn(10, 2 * FLAGS.num_filters)
if FLAGS.joint_shape:
latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1]
@ -252,7 +307,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
x_output_wrap[:, 1:-1, 1:-1] = x_output
data_try_wrap[:, 1:-1, 1:-1] = data_try
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
-1, 66 * 2
)
impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output)
print("Successfully saved images at {}".format(impath))
@ -261,38 +318,40 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
def gentest(sess, kvs, data, latents, save_exp_dir):
X_NOISE = kvs['X_NOISE']
LABEL_SIZE = kvs['LABEL_SIZE']
LABEL_SHAPE = kvs['LABEL_SHAPE']
LABEL_POS = kvs['LABEL_POS']
LABEL_ROT = kvs['LABEL_ROT']
model_size = kvs['model_size']
model_shape = kvs['model_shape']
model_pos = kvs['model_pos']
model_rot = kvs['model_rot']
weight_size = kvs['weight_size']
weight_shape = kvs['weight_shape']
weight_pos = kvs['weight_pos']
weight_rot = kvs['weight_rot']
X_NOISE = kvs["X_NOISE"]
LABEL_SIZE = kvs["LABEL_SIZE"]
LABEL_SHAPE = kvs["LABEL_SHAPE"]
LABEL_POS = kvs["LABEL_POS"]
LABEL_ROT = kvs["LABEL_ROT"]
model_size = kvs["model_size"]
model_shape = kvs["model_shape"]
model_pos = kvs["model_pos"]
model_rot = kvs["model_rot"]
weight_size = kvs["weight_size"]
weight_shape = kvs["weight_shape"]
weight_pos = kvs["weight_pos"]
weight_rot = kvs["weight_rot"]
X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
datafull = data
# Test combination of generalization where we use slices of both training
x_final = X_NOISE
x_mod_size = X_NOISE
x_mod_pos = X_NOISE
for i in range(FLAGS.num_steps):
# use cond_pos
energies = []
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
x_mod_pos = x_mod_pos + tf.random_normal(
tf.shape(x_mod_pos), mean=0.0, stddev=0.005
)
e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)
# energies.append(e_noise)
x_grad = tf.gradients(e_noise, [x_final])[0]
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
x_mod_pos = x_mod_pos + tf.random_normal(
tf.shape(x_mod_pos), mean=0.0, stddev=0.005
)
x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)
@ -332,42 +391,70 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
x_mod = x_mod_pos
x_final = x_mod
if FLAGS.joint_shape:
loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
loss_kl = model_shape.forward(
x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True
) + model_pos.forward(
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
)
energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_pos = model_shape.forward(
X, weight_shape, reuse=True, label=LABEL_SHAPE
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_shape.forward(
tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE
) + model_pos.forward(
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
)
elif FLAGS.joint_rot:
loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
loss_kl = model_rot.forward(
x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True
) + model_pos.forward(
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
)
energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_pos = model_rot.forward(
X, weight_rot, reuse=True, label=LABEL_ROT
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_rot.forward(
tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT
) + model_pos.forward(
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
)
else:
loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
loss_kl = model_size.forward(
x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True
) + model_pos.forward(
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
)
energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_pos = model_size.forward(
X, weight_size, reuse=True, label=LABEL_SIZE
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_size.forward(
tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE
) + model_pos.forward(
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
)
energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
energy_neg_reduced = energy_neg - tf.reduce_min(energy_neg)
coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
neg_loss = coeff * (-1*energy_neg) / norm_constant
coeff * (-1 * energy_neg) / norm_constant
loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))
loss_total = (
loss_ml
+ tf.reduce_mean(loss_kl)
+ 1
* (
tf.reduce_mean(tf.square(energy_pos))
+ tf.reduce_mean(tf.square(energy_neg))
)
)
optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
gvs = optimizer.compute_gradients(loss_total)
@ -377,7 +464,13 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
vs = optimizer.variables()
sess.run(tf.variables_initializer(vs))
dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
dataloader = DataLoader(
DSpritesGen(data, latents),
batch_size=FLAGS.batch_size,
num_workers=6,
drop_last=True,
shuffle=True,
)
x_off = tf.reduce_mean(tf.square(x_mod - X))
@ -385,12 +478,10 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
saver = tf.train.Saver()
x_mod = None
if FLAGS.train:
replay_buffer = ReplayBuffer(10000)
for _ in range(1):
for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
data_corrupt = data_corrupt.numpy()[:, :, :]
data = data.numpy()[:, :, :]
@ -398,29 +489,50 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
if x_mod is not None:
replay_buffer.add(x_mod)
replay_batch = replay_buffer.sample(FLAGS.batch_size)
replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
replay_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95
data_corrupt[replay_mask] = replay_batch[replay_mask]
if FLAGS.joint_shape:
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
feed_dict = {
X_NOISE: data_corrupt,
X: data,
LABEL_SHAPE: label_size,
LABEL_POS: label_pos,
}
elif FLAGS.joint_rot:
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
feed_dict = {
X_NOISE: data_corrupt,
X: data,
LABEL_ROT: label_size,
LABEL_POS: label_pos,
}
else:
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}
feed_dict = {
X_NOISE: data_corrupt,
X: data,
LABEL_SIZE: label_size,
LABEL_POS: label_pos,
}
_, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
_, off_value, e_pos, e_neg, x_mod = sess.run(
[train_op, x_off, energy_pos, energy_neg, x_final],
feed_dict=feed_dict,
)
itr += 1
if itr % 10 == 0:
print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr))
print(
"x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(
off_value, e_pos.mean(), e_neg.mean(), itr
)
)
if itr == FLAGS.break_steps:
break
saver.save(sess, osp.join(save_exp_dir, "model_gentest"))
saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))
saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))
saver.restore(sess, osp.join(save_exp_dir, "model_gentest"))
l = latents
@ -429,22 +541,43 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
elif FLAGS.joint_rot:
mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
else:
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
)
data_gen = datafull[mask_gen]
latents_gen = latents[mask_gen]
losses = []
for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
for dat, latent in zip(
np.array_split(data_gen, 120), np.array_split(latents_gen, 120)
):
x = 0.5 + np.random.randn(*dat.shape)
if FLAGS.joint_shape:
feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
feed_dict = {
LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1],
LABEL_POS: latent[:, 4:],
X_NOISE: x,
X: dat,
}
elif FLAGS.joint_rot:
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
feed_dict = {
LABEL_ROT: np.concatenate(
[np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1
),
LABEL_POS: latent[:, 4:],
X_NOISE: x,
X: dat,
}
else:
feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
feed_dict = {
LABEL_SIZE: latent[:, 2:3],
LABEL_POS: latent[:, 4:],
X_NOISE: x,
X: dat,
}
for i in range(2):
x = sess.run([x_final], feed_dict=feed_dict)[0]
@ -461,11 +594,25 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
latent_pos = latents_gen[:10, 4:]
if FLAGS.joint_shape:
feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos}
feed_dict = {
X_NOISE: data_init,
LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32) - 1],
LABEL_POS: latent_pos,
}
elif FLAGS.joint_rot:
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1), LABEL_POS: latent[:10, 4:], X_NOISE: data_init}
feed_dict = {
LABEL_ROT: np.concatenate(
[np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1
),
LABEL_POS: latent[:10, 4:],
X_NOISE: data_init,
}
else:
feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}
feed_dict = {
X_NOISE: data_init,
LABEL_SIZE: latent_scale,
LABEL_POS: latent_pos,
}
x_output = sess.run([x_final], feed_dict=feed_dict)[0]
@ -480,27 +627,28 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
x_output_wrap[:, 1:-1, 1:-1] = x_output
data_try_wrap[:, 1:-1, 1:-1] = data_try
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
-1, 66 * 2
)
impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output)
print("Successfully saved images at {}".format(impath))
def conceptcombine(sess, kvs, data, latents, save_exp_dir):
X_NOISE = kvs['X_NOISE']
LABEL_SIZE = kvs['LABEL_SIZE']
LABEL_SHAPE = kvs['LABEL_SHAPE']
LABEL_POS = kvs['LABEL_POS']
LABEL_ROT = kvs['LABEL_ROT']
model_size = kvs['model_size']
model_shape = kvs['model_shape']
model_pos = kvs['model_pos']
model_rot = kvs['model_rot']
weight_size = kvs['weight_size']
weight_shape = kvs['weight_shape']
weight_pos = kvs['weight_pos']
weight_rot = kvs['weight_rot']
X_NOISE = kvs["X_NOISE"]
LABEL_SIZE = kvs["LABEL_SIZE"]
LABEL_SHAPE = kvs["LABEL_SHAPE"]
LABEL_POS = kvs["LABEL_POS"]
LABEL_ROT = kvs["LABEL_ROT"]
model_size = kvs["model_size"]
model_shape = kvs["model_shape"]
model_pos = kvs["model_pos"]
model_rot = kvs["model_rot"]
weight_size = kvs["weight_size"]
weight_shape = kvs["weight_shape"]
weight_pos = kvs["weight_pos"]
weight_rot = kvs["weight_rot"]
x_mod = X_NOISE
for i in range(FLAGS.num_steps):
@ -540,13 +688,18 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
data_try = data[:10]
data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
label_scale = latents[:10, 2:3]
label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)]
label_shape = np.eye(3)[(latents[:10, 1] - 1).astype(np.uint8)]
label_rot = latents[:10, 3:4]
label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
label_pos = latents[:10, 4:]
feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos,
LABEL_ROT: label_rot}
feed_dict = {
X_NOISE: data_init,
LABEL_SIZE: label_scale,
LABEL_SHAPE: label_shape,
LABEL_POS: label_pos,
LABEL_ROT: label_rot,
}
x_out = sess.run([x_final], feed_dict)[0]
im_name = "im"
@ -569,14 +722,15 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
x_out_pad[:, 1:-1, 1:-1] = x_out
data_try_pad[:, 1:-1, 1:-1] = data_try
im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2)
im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66 * 2)
impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output)
print("Successfully saved images at {}".format(impath))
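The conditioning labels fed to the four models are derived from the dSprites latent columns. The following is a minimal NumPy sketch of that encoding, assuming the column order (color, shape, scale, rotation, pos_x, pos_y) that the indexing in this file suggests. The helper names `encode_shape` and `encode_rotation` are hypothetical and do not exist in the repository.

```python
import numpy as np


def encode_shape(shape_ids):
    """One-hot encode shape ids, assumed to be 1-based (1, 2, 3)."""
    return np.eye(3)[np.asarray(shape_ids).astype(np.int32) - 1]


def encode_rotation(angles):
    """Map angles in radians to (cos, sin) pairs, one row per angle."""
    angles = np.asarray(angles, dtype=np.float64).reshape(-1, 1)
    return np.concatenate([np.cos(angles), np.sin(angles)], axis=1)


# Two made-up latent rows: color, shape, scale, rotation, pos_x, pos_y
latents = np.array(
    [
        [1.0, 1.0, 0.5, 0.0, 0.0, 0.0],
        [1.0, 3.0, 1.0, np.pi / 2, 16 / 31, 16 / 31],
    ]
)

label_shape = encode_shape(latents[:, 1])   # shape (2, 3)
label_rot = encode_rotation(latents[:, 3])  # shape (2, 2)
label_scale = latents[:, 2:3]               # shape (2, 1), slice keeps the axis
label_pos = latents[:, 4:]                  # shape (2, 2)
```

Encoding the rotation as a (cos, sin) pair presumably avoids the discontinuity between 0 and 2π. The slice `2:3` rather than the index `2` keeps the trailing axis that a `(None, 1)` placeholder would expect.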
def main():
data = np.load(FLAGS.dsprites_path)['imgs']
l = latents = np.load(FLAGS.dsprites_path)['latents_values']
data = np.load(FLAGS.dsprites_path)["imgs"]
l = latents = np.load(FLAGS.dsprites_path)["latents_values"]
np.random.seed(1)
idx = np.random.permutation(data.shape[0])
@ -589,52 +743,74 @@ def main():
# Model 1 will be conditioned on size
model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
weight_size = model_size.construct_weights('context_0')
weight_size = model_size.construct_weights("context_0")
# Model 2 will be conditioned on shape
model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
weight_shape = model_shape.construct_weights('context_1')
weight_shape = model_shape.construct_weights("context_1")
# Model 3 will be conditioned on position
model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
weight_pos = model_pos.construct_weights('context_2')
weight_pos = model_pos.construct_weights("context_2")
# Model 4 will be conditioned on rotation
model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
weight_rot = model_rot.construct_weights('context_3')
weight_rot = model_rot.construct_weights("context_3")
sess.run(tf.global_variables_initializer())
save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size))
save_path_size = osp.join(
FLAGS.logdir, FLAGS.exp_size, "model_{}".format(FLAGS.resume_size)
)
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0))
v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list}
v_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(0)
)
v_map = {
(v.name.replace("context_{}".format(0), "context_0")[:-2]): v for v in v_list
}
if FLAGS.cond_scale:
saver = tf.train.Saver(v_map)
saver.restore(sess, save_path_size)
save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape))
save_path_shape = osp.join(
FLAGS.logdir, FLAGS.exp_shape, "model_{}".format(FLAGS.resume_shape)
)
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
v_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(1)
)
v_map = {
(v.name.replace("context_{}".format(1), "context_0")[:-2]): v for v in v_list
}
if FLAGS.cond_shape:
saver = tf.train.Saver(v_map)
saver.restore(sess, save_path_shape)
save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos))
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
save_path_pos = osp.join(
FLAGS.logdir, FLAGS.exp_pos, "model_{}".format(FLAGS.resume_pos)
)
v_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(2)
)
v_map = {
(v.name.replace("context_{}".format(2), "context_0")[:-2]): v for v in v_list
}
saver = tf.train.Saver(v_map)
if FLAGS.cond_pos:
saver.restore(sess, save_path_pos)
save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot))
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3))
v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list}
save_path_rot = osp.join(
FLAGS.logdir, FLAGS.exp_rot, "model_{}".format(FLAGS.resume_rot)
)
v_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(3)
)
v_map = {
(v.name.replace("context_{}".format(3), "context_0")[:-2]): v for v in v_list
}
saver = tf.train.Saver(v_map)
if FLAGS.cond_rot:
@ -646,53 +822,57 @@ def main():
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)
x_mod = X_NOISE
kvs = {}
kvs['X_NOISE'] = X_NOISE
kvs['LABEL_SIZE'] = LABEL_SIZE
kvs['LABEL_SHAPE'] = LABEL_SHAPE
kvs['LABEL_POS'] = LABEL_POS
kvs['LABEL_ROT'] = LABEL_ROT
kvs['model_size'] = model_size
kvs['model_shape'] = model_shape
kvs['model_pos'] = model_pos
kvs['model_rot'] = model_rot
kvs['weight_size'] = weight_size
kvs['weight_shape'] = weight_shape
kvs['weight_pos'] = weight_pos
kvs['weight_rot'] = weight_rot
kvs["X_NOISE"] = X_NOISE
kvs["LABEL_SIZE"] = LABEL_SIZE
kvs["LABEL_SHAPE"] = LABEL_SHAPE
kvs["LABEL_POS"] = LABEL_POS
kvs["LABEL_ROT"] = LABEL_ROT
kvs["model_size"] = model_size
kvs["model_shape"] = model_shape
kvs["model_pos"] = model_pos
kvs["model_rot"] = model_rot
kvs["weight_size"] = weight_size
kvs["weight_shape"] = weight_shape
kvs["weight_pos"] = weight_pos
kvs["weight_rot"] = weight_rot
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape))
save_exp_dir = osp.join(
FLAGS.savedir, "{}_{}_joint".format(FLAGS.exp_size, FLAGS.exp_shape)
)
if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir)
if FLAGS.task == 'conceptcombine':
if FLAGS.task == "conceptcombine":
conceptcombine(sess, kvs, data, latents, save_exp_dir)
elif FLAGS.task == 'labeldiscover':
elif FLAGS.task == "labeldiscover":
labeldiscover(sess, kvs, data, latents, save_exp_dir)
elif FLAGS.task == 'gentest':
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos))
elif FLAGS.task == "gentest":
save_exp_dir = osp.join(
FLAGS.savedir, "{}_{}_gen".format(FLAGS.exp_size, FLAGS.exp_pos)
)
if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir)
gentest(sess, kvs, data, latents, save_exp_dir)
elif FLAGS.task == 'genbaseline':
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos))
elif FLAGS.task == "genbaseline":
save_exp_dir = osp.join(
FLAGS.savedir, "{}_{}_gen_baseline".format(FLAGS.exp_size, FLAGS.exp_pos)
)
if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir)
if FLAGS.plot_curve:
mse_losses = []
for frac in [i/10 for i in range(11)]:
mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac)
for frac in [i / 10 for i in range(11)]:
mse_loss = genbaseline(
sess, kvs, data, latents, save_exp_dir, frac=frac
)
mse_losses.append(mse_loss)
np.save("mse_baseline_comb.npy", mse_losses)
else:
genbaseline(sess, kvs, data, latents, save_exp_dir)
if __name__ == "__main__":
main()
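
The sampling loops in this file follow one pattern: add Gaussian noise, take a gradient step on the energy scaled by `step_lr`, then clip to `[0, 1]`. The sketch below reproduces that pattern in plain NumPy on a toy quadratic energy. It is an illustration under stated assumptions, not the TensorFlow graph the file builds: `langevin_sample`, `toy_grad`, and the quadratic energy are hypothetical, and the toy example uses `step_lr=0.5` instead of the flag default of 500 so that the quadratic case stays stable.

```python
import numpy as np


def langevin_sample(x_init, energy_grad, num_steps=20, step_lr=500.0,
                    noise_std=0.005, rng=None):
    """Noisy gradient descent on an energy, clipped to [0, 1] after each step."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(x_init, dtype=np.float64)
    for _ in range(num_steps):
        x = x + rng.normal(0.0, noise_std, size=x.shape)  # perturb the sample
        x = x - step_lr * energy_grad(x)                  # descend the energy
        x = np.clip(x, 0.0, 1.0)                          # keep valid pixel range
    return x


# Toy energy E(x) = 0.5 * ||x - target||^2, whose gradient is x - target.
target = np.full((4, 8, 8), 0.25)


def toy_grad(x):
    return x - target


x0 = np.random.default_rng(0).uniform(size=target.shape)
samples = langevin_sample(
    x0, toy_grad, step_lr=0.5, rng=np.random.default_rng(1)
)
```

With a learned energy network, `energy_grad` would be replaced by the gradient of the model's output with respect to its input, which is what `tf.gradients` provides in the TensorFlow version.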

File diff suppressed because it is too large

@ -1,5 +1,5 @@
#!/usr/bin/env python3
''' Calculates the Frechet Inception Distance (FID) to evaluate GANs.
"""Calculates the Frechet Inception Distance (FID) to evaluate GANs.
The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
@ -14,28 +14,33 @@ the pool_3 layer of the inception net for generated samples and real world
samples respectively.
See --help to see further details.
'''
"""
from __future__ import absolute_import, division, print_function
import numpy as np
import os
import gzip, pickle
import tensorflow as tf
from scipy.misc import imread
from scipy import linalg
import pathlib
import urllib
import tarfile
import urllib
import warnings
MODEL_DIR = '/tmp/imagenet'
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
import numpy as np
import tensorflow as tf
from scipy import linalg
from scipy.misc import imread
MODEL_DIR = "/tmp/imagenet"
DATA_URL = (
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
)
pool3 = None
class InvalidFIDException(Exception):
pass
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def get_fid_score(images, images_gt):
images = np.stack(images, 0)
images_gt = np.stack(images_gt, 0)
@ -52,34 +57,38 @@ def get_fid_score(images, images_gt):
def create_inception_graph(pth):
"""Creates a graph from saved GraphDef file."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile( pth, 'rb') as f:
with tf.gfile.FastGFile(pth, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString( f.read())
_ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
#-------------------------------------------------------------------------------
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name="FID_Inception_Net")
# -------------------------------------------------------------------------------
# code for handling inception net derived from
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
def _get_inception_layer(sess):
"""Prepares inception net for batched usage and returns pool_3 layer. """
layername = 'FID_Inception_Net/pool_3:0'
"""Prepares inception net for batched usage and returns pool_3 layer."""
layername = "FID_Inception_Net/pool_3:0"
pool3 = sess.graph.get_tensor_by_name(layername)
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
if shape._dims != []:
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
return pool3
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def get_activations(images, sess, batch_size=50, verbose=False):
@ -100,23 +109,27 @@ def get_activations(images, sess, batch_size=50, verbose=False):
# inception_layer = _get_inception_layer(sess)
d0 = images.shape[0]
if batch_size > d0:
print("warning: batch size is bigger than the data size. setting batch size to data size")
print(
"warning: batch size is bigger than the data size. setting batch size to data size"
)
batch_size = d0
n_batches = d0//batch_size
n_used_imgs = n_batches*batch_size
pred_arr = np.empty((n_used_imgs,2048))
n_batches = d0 // batch_size
n_used_imgs = n_batches * batch_size
pred_arr = np.empty((n_used_imgs, 2048))
for i in range(n_batches):
if verbose:
print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
start = i*batch_size
print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
start = i * batch_size
end = start + batch_size
batch = images[start:end]
pred = sess.run(pool3, {'ExpandDims:0': batch})
pred_arr[start:end] = pred.reshape(batch_size,-1)
pred = sess.run(pool3, {"ExpandDims:0": batch})
pred_arr[start:end] = pred.reshape(batch_size, -1)
if verbose:
print(" done")
return pred_arr
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
@ -147,15 +160,22 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
assert (
mu1.shape == mu2.shape
), "Training and test mean vectors have different lengths"
assert (
sigma1.shape == sigma2.shape
), "Training and test covariances have different dimensions"
diff = mu1 - mu2
# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
msg = (
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
% eps
)
warnings.warn(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
@ -170,7 +190,9 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
@ -193,47 +215,52 @@ def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
# The following functions aren't needed for calculating the FID
# they're just here to make this module work as a stand-alone script
# for calculating FID scores
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def check_or_download_inception(inception_path):
''' Checks if the path to the inception file is valid, or downloads
the file if it is not present. '''
INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
"""Checks if the path to the inception file is valid, or downloads
the file if it is not present."""
INCEPTION_URL = (
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
)
if inception_path is None:
inception_path = '/tmp'
inception_path = "/tmp"
inception_path = pathlib.Path(inception_path)
model_file = inception_path / 'classify_image_graph_def.pb'
model_file = inception_path / "classify_image_graph_def.pb"
if not model_file.exists():
print("Downloading Inception model")
from urllib import request
import tarfile
from urllib import request
fn, _ = request.urlretrieve(INCEPTION_URL)
with tarfile.open(fn, mode='r') as f:
f.extract('classify_image_graph_def.pb', str(model_file.parent))
with tarfile.open(fn, mode="r") as f:
f.extract("classify_image_graph_def.pb", str(model_file.parent))
return str(model_file)
def _handle_path(path, sess):
if path.endswith('.npz'):
if path.endswith(".npz"):
f = np.load(path)
m, s = f['mu'][:], f['sigma'][:]
m, s = f["mu"][:], f["sigma"][:]
f.close()
else:
path = pathlib.Path(path)
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
m, s = calculate_activation_statistics(x, sess)
return m, s
def calculate_fid_given_paths(paths, inception_path):
''' Calculates the FID of two paths. '''
"""Calculates the FID of two paths."""
inception_path = check_or_download_inception(inception_path)
for p in paths:
@ -250,43 +277,48 @@ def calculate_fid_given_paths(paths, inception_path):
def _init_inception():
global pool3
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(MODEL_DIR, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
with tf.gfile.FastGFile(os.path.join(
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
# Works with an arbitrary minibatch size.
with tf.Session() as sess:
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
if shape._dims != []:
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
global pool3
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
filename = DATA_URL.split("/")[-1]
filepath = os.path.join(MODEL_DIR, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write(
"\r>> Downloading %s %.1f%%"
% (filename, float(count * block_size) / float(total_size) * 100.0)
)
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print("Successfully downloaded", filename, statinfo.st_size, "bytes.")
tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
with tf.gfile.FastGFile(
os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
) as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name="")
# Works with an arbitrary minibatch size.
with tf.Session() as sess:
pool3 = sess.graph.get_tensor_by_name("pool_3:0")
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
if shape._dims != []:
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
if pool3 is None:
_init_inception()
_init_inception()
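
The Fréchet distance computed by `calculate_frechet_distance` above can be sketched on its own, without the Inception graph. This is a minimal sketch assuming only NumPy and SciPy; the name `frechet_distance` and the synthetic "activations" are illustrative and not part of the repository, and the sketch does not reproduce every check in the original function.

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Squared Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2

    # Matrix square root of the covariance product; it can be near-singular.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Same remedy as the function above: nudge both diagonals by eps.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can leave a tiny imaginary part; keep the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return (
        diff.dot(diff)
        + np.trace(sigma1)
        + np.trace(sigma2)
        - 2.0 * np.trace(covmean)
    )


# Hypothetical stand-ins for pool_3 activations (4 features instead of 2048).
rng = np.random.default_rng(0)
act_a = rng.normal(size=(500, 4))
act_b = rng.normal(loc=1.0, size=(500, 4))
mu_a, sigma_a = act_a.mean(axis=0), np.cov(act_a, rowvar=False)
mu_b, sigma_b = act_b.mean(axis=0), np.cov(act_b, rowvar=False)

fid_same = frechet_distance(mu_a, sigma_a, mu_a, sigma_a)
fid_diff = frechet_distance(mu_a, sigma_a, mu_b, sigma_b)
```

Comparing a distribution with itself should give a value near zero, while shifting the mean increases the distance; the statistics are computed the same way as in `calculate_activation_statistics` (`np.mean` over axis 0 and `np.cov` with `rowvar=False`).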


@ -1,11 +1,11 @@
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import flags
flags.DEFINE_bool('proposal_debug', False, 'Print hmc acceptance rates')
flags.DEFINE_bool("proposal_debug", False, "Print hmc acceptance rates")
FLAGS = flags.FLAGS
def kinetic_energy(velocity):
"""Kinetic energy of the current velocity (assuming a standard Gaussian)
(x dot x) / 2
@ -21,6 +21,7 @@ def kinetic_energy(velocity):
"""
return 0.5 * tf.square(velocity)
def hamiltonian(position, velocity, energy_function):
"""Computes the Hamiltonian of the current position, velocity pair
@ -44,13 +45,12 @@ def hamiltonian(position, velocity, energy_function):
"""
batch_size = tf.shape(velocity)[0]
kinetic_energy_flat = tf.reshape(kinetic_energy(velocity), (batch_size, -1))
return tf.squeeze(energy_function(position)) + tf.reduce_sum(kinetic_energy_flat, axis=[1])
return tf.squeeze(energy_function(position)) + tf.reduce_sum(
kinetic_energy_flat, axis=[1]
)
def leapfrog_step(x0,
v0,
neg_log_posterior,
step_size,
num_steps):
def leapfrog_step(x0, v0, neg_log_posterior, step_size, num_steps):
# Start by updating the velocity a half-step
v = v0 - 0.5 * step_size * tf.gradients(neg_log_posterior(x0), x0)[0]
@ -83,10 +83,8 @@ def leapfrog_step(x0,
# return new proposal state
return x, v
def hmc(initial_x,
step_size,
num_steps,
neg_log_posterior):
def hmc(initial_x, step_size, num_steps, neg_log_posterior):
"""Summary
Parameters
@ -107,11 +105,13 @@ def hmc(initial_x,
"""
v0 = tf.random_normal(tf.shape(initial_x))
x, v = leapfrog_step(initial_x,
v0,
step_size=step_size,
num_steps=num_steps,
neg_log_posterior=neg_log_posterior)
x, v = leapfrog_step(
initial_x,
v0,
step_size=step_size,
num_steps=num_steps,
neg_log_posterior=neg_log_posterior,
)
orig = hamiltonian(initial_x, v0, neg_log_posterior)
current = hamiltonian(x, v, neg_log_posterior)
@ -119,10 +119,12 @@ def hmc(initial_x,
prob_accept = tf.exp(orig - current)
if FLAGS.proposal_debug:
prob_accept = tf.Print(prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))])
prob_accept = tf.Print(
prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))]
)
uniform = tf.random_uniform(tf.shape(prob_accept))
keep_mask = (prob_accept > uniform)
keep_mask = prob_accept > uniform
# print(keep_mask.get_shape())
x_new = tf.where(keep_mask, x, initial_x)
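
The leapfrog and accept/reject logic in this file can be tried outside a TensorFlow graph. The following is a rough NumPy sketch, assuming a hypothetical quadratic energy with an analytic gradient in place of `tf.gradients`; `energy`, `grad_energy`, and `hmc_step` are illustrative names, not functions from the repository.

```python
import numpy as np


def energy(x):
    # Hypothetical energy: negative log-density of a standard Gaussian
    # (up to a constant), one value per batch row.
    return 0.5 * np.sum(x ** 2, axis=1)


def grad_energy(x):
    # Analytic gradient of the energy above.
    return x


def hamiltonian(x, v):
    # Potential energy plus kinetic energy (v dot v) / 2 per batch row.
    return energy(x) + 0.5 * np.sum(v ** 2, axis=1)


def leapfrog(x0, v0, step_size, num_steps):
    v = v0 - 0.5 * step_size * grad_energy(x0)  # initial velocity half-step
    x = x0 + step_size * v                      # first full position step
    for _ in range(num_steps - 1):
        v = v - step_size * grad_energy(x)      # full velocity step
        x = x + step_size * v                   # full position step
    v = v - 0.5 * step_size * grad_energy(x)    # final velocity half-step
    return x, v


def hmc_step(x0, step_size, num_steps, rng):
    v0 = rng.standard_normal(x0.shape)
    x, v = leapfrog(x0, v0, step_size, num_steps)
    # Metropolis test: accept with probability min(1, exp(H_old - H_new)).
    prob_accept = np.exp(hamiltonian(x0, v0) - hamiltonian(x, v))
    keep = prob_accept > rng.uniform(size=prob_accept.shape)
    return np.where(keep[:, None], x, x0), keep


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 2))
accepted = []
for _ in range(200):
    x, keep = hmc_step(x, step_size=0.1, num_steps=10, rng=rng)
    accepted.append(keep.mean())
accept_rate = float(np.mean(accepted))
```

With a small step size the leapfrog integrator nearly conserves the Hamiltonian, so most proposals are accepted; this is the quantity the `proposal_debug` flag prints in the TensorFlow version.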


@ -1,24 +1,28 @@
from models import ResNet128
import numpy as np
import os.path as osp
from tensorflow.python.platform import flags
import tensorflow as tf
import imageio
from utils import optimistic_restore
import numpy as np
import tensorflow as tf
from models import ResNet128
from tensorflow.python.platform import flags
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_integer('num_steps', 200, 'num of steps for conditional imagenet sampling')
flags.DEFINE_float('step_lr', 180., 'step size for Langevin dynamics')
flags.DEFINE_integer('batch_size', 16, 'number of images to sample per batch')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization in weights in a model')
flags.DEFINE_bool('cclass', True, 'conditional models')
flags.DEFINE_bool('use_attention', False, 'using attention')
flags.DEFINE_string(
"logdir", "cachedir", "location where log of experiments will be stored"
)
flags.DEFINE_integer("num_steps", 200, "num of steps for conditional imagenet sampling")
flags.DEFINE_float("step_lr", 180.0, "step size for Langevin dynamics")
flags.DEFINE_integer("batch_size", 16, "number of images to sample per batch")
flags.DEFINE_string("exp", "default", "name of experiments")
flags.DEFINE_integer("resume_iter", -1, "iteration to resume training from")
flags.DEFINE_bool(
"spec_norm", True, "whether to use spectral normalization in weights in a model"
)
flags.DEFINE_bool("cclass", True, "conditional models")
flags.DEFINE_bool("use_attention", False, "using attention")
FLAGS = flags.FLAGS
def rescale_im(im):
return np.clip(im * 256, 0, 255).astype(np.uint8)
@ -32,12 +36,11 @@ if __name__ == "__main__":
weights = model.construct_weights("context_0")
x_mod = X_NOISE
x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
mean=0.0,
stddev=0.005)
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
energy_noise = energy_start = model.forward(x_mod, weights, label=LABEL,
reuse=True, stop_at_grad=False, stop_batch=True)
energy_noise = energy_start = model.forward(
x_mod, weights, label=LABEL, reuse=True, stop_at_grad=False, stop_batch=True
)
x_grad = tf.gradients(energy_noise, [x_mod])[0]
energy_noise_old = energy_noise
@ -54,20 +57,23 @@ if __name__ == "__main__":
saver = loader = tf.train.Saver()
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
saver.restore(sess, model_file)
lx = np.random.permutation(1000)[:16]
ims = []
# What to initialize sampling with.
# What to initialize sampling with.
x_mod = np.random.uniform(0, 1, size=(FLAGS.batch_size, 128, 128, 3))
labels = np.eye(1000)[lx]
for i in range(FLAGS.num_steps):
e, x_mod = sess.run([energy_noise, x_output], {X_NOISE:x_mod, LABEL:labels})
ims.append(rescale_im(x_mod).reshape((4, 4, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((512, 512, 3)))
imageio.mimwrite('sample.gif', ims)
e, x_mod = sess.run([energy_noise, x_output], {X_NOISE: x_mod, LABEL: labels})
ims.append(
rescale_im(x_mod)
.reshape((4, 4, 128, 128, 3))
.transpose((0, 2, 1, 3, 4))
.reshape((512, 512, 3))
)
imageio.mimwrite("sample.gif", ims)
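
The reshape/transpose chain that builds each GIF frame turns a batch of 16 images into a single 4x4 grid. Here is a small NumPy sketch of the same trick on tiny arrays; `tile_batch` is a hypothetical helper name, and the 2x3 "images" are only there to keep the example readable.

```python
import numpy as np


def tile_batch(images, rows, cols):
    """Arrange a (rows*cols, H, W, C) batch into one (rows*H, cols*W, C) image."""
    n, h, w, c = images.shape
    assert n == rows * cols, "batch size must equal rows * cols"
    grid = images.reshape((rows, cols, h, w, c))
    # Move each tile's height axis next to the row axis: (rows, H, cols, W, C).
    grid = grid.transpose((0, 2, 1, 3, 4))
    return grid.reshape((rows * h, cols * w, c))


# 16 tiny "images" of height 2, width 3, with 3 channels.
batch = np.arange(16 * 2 * 3 * 3).reshape((16, 2, 3, 3))
canvas = tile_batch(batch, rows=4, cols=4)

# The tile at grid position (row=1, col=1) is image number 1 * 4 + 1 = 5.
tile = canvas[2:4, 3:6]
```

Without the transpose, the final reshape would interleave rows from different images; swapping axes 1 and 2 first keeps every image contiguous inside its tile.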


@ -13,14 +13,11 @@
# limitations under the License.
# ==============================================================================
"""Image pre-processing utilities.
"""
"""Image pre-processing utilities."""
import tensorflow as tf
IMAGE_DEPTH = 3 # color images
IMAGE_DEPTH = 3 # color images
import tensorflow as tf
# _R_MEAN = 123.68
# _G_MEAN = 116.78
@ -35,303 +32,318 @@ _RESIZE_MIN = 128
def _decode_crop_and_flip(image_buffer, bbox, num_channels):
"""Crops the given image to a random part of the image, and randomly flips.
"""Crops the given image to a random part of the image, and randomly flips.
We use the fused decode_and_crop op, which performs better than the two ops
used separately in series, but note that this requires that the image be
passed in as an un-decoded string Tensor.
We use the fused decode_and_crop op, which performs better than the two ops
used separately in series, but note that this requires that the image be
passed in as an un-decoded string Tensor.
Args:
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
num_channels: Integer depth of the image buffer for decoding.
Args:
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
num_channels: Integer depth of the image buffer for decoding.
Returns:
3-D tensor with cropped image.
Returns:
3-D tensor with cropped image.
"""
# A large fraction of image datasets contain a human-annotated bounding box
# delineating the region of the image containing the object of interest. We
# choose to create a new bounding box for the object which is a randomly
# distorted version of the human-annotated bounding box that obeys an
# allowed range of aspect ratios, sizes and overlap with the human-annotated
# bounding box. If no box is supplied, then we assume the bounding box is
# the entire image.
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
tf.image.extract_jpeg_shape(image_buffer),
bounding_boxes=bbox,
min_object_covered=0.1,
aspect_ratio_range=[0.75, 1.33],
area_range=[0.05, 1.0],
max_attempts=100,
use_image_if_no_bounding_boxes=True)
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
"""
# A large fraction of image datasets contain a human-annotated bounding box
# delineating the region of the image containing the object of interest. We
# choose to create a new bounding box for the object which is a randomly
# distorted version of the human-annotated bounding box that obeys an
# allowed range of aspect ratios, sizes and overlap with the human-annotated
# bounding box. If no box is supplied, then we assume the bounding box is
# the entire image.
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
tf.image.extract_jpeg_shape(image_buffer),
bounding_boxes=bbox,
min_object_covered=0.1,
aspect_ratio_range=[0.75, 1.33],
area_range=[0.05, 1.0],
max_attempts=100,
use_image_if_no_bounding_boxes=True,
)
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
# Reassemble the bounding box in the format the crop op requires.
offset_y, offset_x, _ = tf.unstack(bbox_begin)
target_height, target_width, _ = tf.unstack(bbox_size)
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
# Reassemble the bounding box in the format the crop op requires.
offset_y, offset_x, _ = tf.unstack(bbox_begin)
target_height, target_width, _ = tf.unstack(bbox_size)
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
# Use the fused decode and crop op here, which is faster than each in series.
cropped = tf.image.decode_and_crop_jpeg(
image_buffer, crop_window, channels=num_channels)
# Use the fused decode and crop op here, which is faster than each in
# series.
cropped = tf.image.decode_and_crop_jpeg(
image_buffer, crop_window, channels=num_channels
)
# Flip to add a little more random distortion in.
cropped = tf.image.random_flip_left_right(cropped)
return cropped
# Flip to add a little more random distortion in.
cropped = tf.image.random_flip_left_right(cropped)
return cropped
def _central_crop(image, crop_height, crop_width):
"""Performs central crops of the given image list.
"""Performs central crops of the given image list.
Args:
image: a 3-D image tensor
crop_height: the height of the image following the crop.
crop_width: the width of the image following the crop.
Args:
image: a 3-D image tensor
crop_height: the height of the image following the crop.
crop_width: the width of the image following the crop.
Returns:
3-D tensor with cropped image.
"""
shape = tf.shape(input=image)
height, width = shape[0], shape[1]
Returns:
3-D tensor with cropped image.
"""
shape = tf.shape(input=image)
height, width = shape[0], shape[1]
amount_to_be_cropped_h = (height - crop_height)
crop_top = amount_to_be_cropped_h // 2
amount_to_be_cropped_w = (width - crop_width)
crop_left = amount_to_be_cropped_w // 2
return tf.slice(
image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
amount_to_be_cropped_h = height - crop_height
crop_top = amount_to_be_cropped_h // 2
amount_to_be_cropped_w = width - crop_width
crop_left = amount_to_be_cropped_w // 2
return tf.slice(image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
def _mean_image_subtraction(image, means, num_channels):
"""Subtracts the given means from each image channel.
"""Subtracts the given means from each image channel.
For example:
means = [123.68, 116.779, 103.939]
image = _mean_image_subtraction(image, means, num_channels=3)
For example:
means = [123.68, 116.779, 103.939]
image = _mean_image_subtraction(image, means, num_channels=3)
Note that the rank of `image` must be known.
Note that the rank of `image` must be known.
Args:
image: a tensor of size [height, width, C].
means: a C-vector of values to subtract from each channel.
num_channels: number of color channels in the image that will be distorted.
Args:
image: a tensor of size [height, width, C].
means: a C-vector of values to subtract from each channel.
num_channels: number of color channels in the image that will be distorted.
Returns:
the centered image.
Returns:
the centered image.
Raises:
ValueError: If the rank of `image` is unknown, if `image` has a rank other
than three or if the number of channels in `image` doesn't match the
number of values in `means`.
"""
if image.get_shape().ndims != 3:
raise ValueError('Input must be of size [height, width, C>0]')
Raises:
ValueError: If the rank of `image` is unknown, if `image` has a rank other
than three or if the number of channels in `image` doesn't match the
number of values in `means`.
"""
if image.get_shape().ndims != 3:
raise ValueError("Input must be of size [height, width, C>0]")
if len(means) != num_channels:
raise ValueError('len(means) must match the number of channels')
if len(means) != num_channels:
raise ValueError("len(means) must match the number of channels")
# We have a 1-D tensor of means; convert to 3-D.
means = tf.expand_dims(tf.expand_dims(means, 0), 0)
# We have a 1-D tensor of means; convert to 3-D.
means = tf.expand_dims(tf.expand_dims(means, 0), 0)
return image - means
return image - means
def _smallest_size_at_least(height, width, resize_min):
"""Computes new shape with the smallest side equal to `resize_min`.
"""Computes new shape with the smallest side equal to `resize_min`.
Computes new shape with the smallest side equal to `resize_min` while
preserving the original aspect ratio.
Computes new shape with the smallest side equal to `resize_min` while
preserving the original aspect ratio.
Args:
height: an int32 scalar tensor indicating the current height.
width: an int32 scalar tensor indicating the current width.
resize_min: A python integer or scalar `Tensor` indicating the size of
the smallest side after resize.
Args:
height: an int32 scalar tensor indicating the current height.
width: an int32 scalar tensor indicating the current width.
resize_min: A python integer or scalar `Tensor` indicating the size of
the smallest side after resize.
Returns:
new_height: an int32 scalar tensor indicating the new height.
new_width: an int32 scalar tensor indicating the new width.
"""
resize_min = tf.cast(resize_min, tf.float32)
Returns:
new_height: an int32 scalar tensor indicating the new height.
new_width: an int32 scalar tensor indicating the new width.
"""
resize_min = tf.cast(resize_min, tf.float32)
# Convert to floats to make subsequent calculations go smoothly.
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
# Convert to floats to make subsequent calculations go smoothly.
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
smaller_dim = tf.minimum(height, width)
scale_ratio = resize_min / smaller_dim
smaller_dim = tf.minimum(height, width)
scale_ratio = resize_min / smaller_dim
# Convert back to ints to make heights and widths that TF ops will accept.
new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
# Convert back to ints to make heights and widths that TF ops will accept.
new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
return new_height, new_width
return new_height, new_width
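
The size arithmetic used by `_smallest_size_at_least` and `_central_crop` can be checked with plain Python integers. This is a sketch under the assumption that the inputs are ordinary numbers rather than tensors; the function names below are illustrative, and the 256x512 input is chosen so that the scale factor is exact in floating point.

```python
import math


def smallest_size_at_least(height, width, resize_min):
    # Scale so that the shorter side becomes resize_min, keeping aspect ratio.
    scale = resize_min / min(height, width)
    return math.ceil(height * scale), math.ceil(width * scale)


def central_crop_window(height, width, crop_height, crop_width):
    # Offsets that center a crop_height x crop_width window in the image.
    top = (height - crop_height) // 2
    left = (width - crop_width) // 2
    return top, left, crop_height, crop_width


# A 256x512 image resized so that its shorter side is 128 pixels.
new_h, new_w = smallest_size_at_least(256, 512, 128)

# Centered 128x128 window inside the resized image.
window = central_crop_window(new_h, new_w, 128, 128)
```

The validation path in `preprocess_image` follows the same order: resize the shorter side to `_RESIZE_MIN` first, then take the central crop.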
def _aspect_preserving_resize(image, resize_min):
"""Resize images preserving the original aspect ratio.
"""Resize images preserving the original aspect ratio.
Args:
image: A 3-D image `Tensor`.
resize_min: A python integer or scalar `Tensor` indicating the size of
the smallest side after resize.
Args:
image: A 3-D image `Tensor`.
resize_min: A python integer or scalar `Tensor` indicating the size of
the smallest side after resize.
Returns:
resized_image: A 3-D tensor containing the resized image.
"""
shape = tf.shape(input=image)
height, width = shape[0], shape[1]
Returns:
resized_image: A 3-D tensor containing the resized image.
"""
shape = tf.shape(input=image)
height, width = shape[0], shape[1]
new_height, new_width = _smallest_size_at_least(height, width, resize_min)
new_height, new_width = _smallest_size_at_least(height, width, resize_min)
return _resize_image(image, new_height, new_width)
return _resize_image(image, new_height, new_width)
def _resize_image(image, height, width):
"""Simple wrapper around tf.resize_images.
"""Simple wrapper around tf.resize_images.
This is primarily to make sure we use the same `ResizeMethod` and other
details each time.
This is primarily to make sure we use the same `ResizeMethod` and other
details each time.
Args:
image: A 3-D image `Tensor`.
height: The target height for the resized image.
width: The target width for the resized image.
Args:
image: A 3-D image `Tensor`.
height: The target height for the resized image.
width: The target width for the resized image.
Returns:
resized_image: A 3-D tensor containing the resized image. The first two
dimensions have the shape [height, width].
"""
return tf.image.resize_images(
image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
align_corners=False)
Returns:
resized_image: A 3-D tensor containing the resized image. The first two
dimensions have the shape [height, width].
"""
return tf.image.resize_images(
image,
[height, width],
method=tf.image.ResizeMethod.BILINEAR,
align_corners=False,
)
def preprocess_image(image_buffer, bbox, output_height, output_width,
num_channels, is_training=False):
"""Preprocesses the given image.
def preprocess_image(
image_buffer, bbox, output_height, output_width, num_channels, is_training=False
):
"""Preprocesses the given image.
Preprocessing includes decoding, cropping, and resizing for both training
and eval images. Training preprocessing, however, introduces some random
distortion of the image to improve accuracy.
Preprocessing includes decoding, cropping, and resizing for both training
and eval images. Training preprocessing, however, introduces some random
distortion of the image to improve accuracy.
Args:
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
output_height: The height of the image after preprocessing.
output_width: The width of the image after preprocessing.
num_channels: Integer depth of the image buffer for decoding.
is_training: `True` if we're preprocessing the image for training and
`False` otherwise.
Args:
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
output_height: The height of the image after preprocessing.
output_width: The width of the image after preprocessing.
num_channels: Integer depth of the image buffer for decoding.
is_training: `True` if we're preprocessing the image for training and
`False` otherwise.
Returns:
A preprocessed image.
"""
if is_training:
# For training, we want to randomize some of the distortions.
image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
image = _resize_image(image, output_height, output_width)
else:
# For validation, we want to decode, resize, then just crop the middle.
image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
image = _aspect_preserving_resize(image, _RESIZE_MIN)
print(image)
image = _central_crop(image, output_height, output_width)
Returns:
A preprocessed image.
"""
if is_training:
# For training, we want to randomize some of the distortions.
image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
image = _resize_image(image, output_height, output_width)
else:
# For validation, we want to decode, resize, then just crop the middle.
image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
image = _aspect_preserving_resize(image, _RESIZE_MIN)
print(image)
image = _central_crop(image, output_height, output_width)
image.set_shape([output_height, output_width, num_channels])
image.set_shape([output_height, output_width, num_channels])
return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
def parse_example_proto(example_serialized):
"""Parses an Example proto containing a training example of an image.
"""Parses an Example proto containing a training example of an image.
The output of the build_image_data.py image preprocessing script is a dataset
containing serialized Example protocol buffers. Each Example proto contains
the following fields:
The output of the build_image_data.py image preprocessing script is a dataset
containing serialized Example protocol buffers. Each Example proto contains
the following fields:
image/height: 462
image/width: 581
image/colorspace: 'RGB'
image/channels: 3
image/class/label: 615
image/class/synset: 'n03623198'
image/class/text: 'knee pad'
image/object/bbox/xmin: 0.1
image/object/bbox/xmax: 0.9
image/object/bbox/ymin: 0.2
image/object/bbox/ymax: 0.6
image/object/bbox/label: 615
image/format: 'JPEG'
image/filename: 'ILSVRC2012_val_00041207.JPEG'
image/encoded: <JPEG encoded string>
image/height: 462
image/width: 581
image/colorspace: 'RGB'
image/channels: 3
image/class/label: 615
image/class/synset: 'n03623198'
image/class/text: 'knee pad'
image/object/bbox/xmin: 0.1
image/object/bbox/xmax: 0.9
image/object/bbox/ymin: 0.2
image/object/bbox/ymax: 0.6
image/object/bbox/label: 615
image/format: 'JPEG'
image/filename: 'ILSVRC2012_val_00041207.JPEG'
image/encoded: <JPEG encoded string>
Args:
example_serialized: scalar Tensor tf.string containing a serialized
Example protocol buffer.
Args:
example_serialized: scalar Tensor tf.string containing a serialized
Example protocol buffer.
Returns:
image_buffer: Tensor tf.string containing the contents of a JPEG file.
label: Tensor tf.int32 containing the label.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
text: Tensor tf.string containing the human-readable label.
"""
# Dense features in Example proto.
feature_map = {
'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
default_value=''),
'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
default_value=-1),
'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
default_value=''),
}
sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
# Sparse features in Example proto.
feature_map.update(
{k: sparse_float32 for k in ['image/object/bbox/xmin',
'image/object/bbox/ymin',
'image/object/bbox/xmax',
'image/object/bbox/ymax']})
Returns:
image_buffer: Tensor tf.string containing the contents of a JPEG file.
label: Tensor tf.int32 containing the label.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
text: Tensor tf.string containing the human-readable label.
"""
# Dense features in Example proto.
feature_map = {
"image/encoded": tf.FixedLenFeature([], dtype=tf.string, default_value=""),
"image/class/label": tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1),
"image/class/text": tf.FixedLenFeature([], dtype=tf.string, default_value=""),
}
sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
# Sparse features in Example proto.
feature_map.update(
{
k: sparse_float32
for k in [
"image/object/bbox/xmin",
"image/object/bbox/ymin",
"image/object/bbox/xmax",
"image/object/bbox/ymax",
]
}
)
features = tf.parse_single_example(example_serialized, feature_map)
label = tf.cast(features['image/class/label'], dtype=tf.int32)
features = tf.parse_single_example(example_serialized, feature_map)
label = tf.cast(features["image/class/label"], dtype=tf.int32)
xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
xmin = tf.expand_dims(features["image/object/bbox/xmin"].values, 0)
ymin = tf.expand_dims(features["image/object/bbox/ymin"].values, 0)
xmax = tf.expand_dims(features["image/object/bbox/xmax"].values, 0)
ymax = tf.expand_dims(features["image/object/bbox/ymax"].values, 0)
# Note that we impose an ordering of (y, x) just to make life difficult.
bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
# Note that we impose an ordering of (y, x) just to make life difficult.
bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
# Force the variable number of bounding boxes into the shape
# [1, num_boxes, coords].
bbox = tf.expand_dims(bbox, 0)
bbox = tf.transpose(bbox, [0, 2, 1])
# Force the variable number of bounding boxes into the shape
# [1, num_boxes, coords].
bbox = tf.expand_dims(bbox, 0)
bbox = tf.transpose(bbox, [0, 2, 1])
return features['image/encoded'], label, bbox, features['image/class/text']
return features["image/encoded"], label, bbox, features["image/class/text"]
class ImagenetPreprocessor:
def __init__(self, image_size, dtype, train):
self.image_size = image_size
self.dtype = dtype
self.train = train
def __init__(self, image_size, dtype, train):
self.image_size = image_size
self.dtype = dtype
self.train = train
def preprocess(self, image_buffer, bbox):
# pylint: disable=g-import-not-at-top
image = preprocess_image(image_buffer, bbox, self.image_size, self.image_size, IMAGE_DEPTH, is_training=self.train)
return tf.cast(image, self.dtype)
def parse_and_preprocess(self, value):
image_buffer, label_index, bbox, _ = parse_example_proto(value)
image = self.preprocess(image_buffer, bbox)
image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
return label_index, image
def preprocess(self, image_buffer, bbox):
# pylint: disable=g-import-not-at-top
image = preprocess_image(
image_buffer,
bbox,
self.image_size,
self.image_size,
IMAGE_DEPTH,
is_training=self.train,
)
return tf.cast(image, self.dtype)
def parse_and_preprocess(self, value):
image_buffer, label_index, bbox, _ = parse_example_proto(value)
image = self.preprocess(image_buffer, bbox)
image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
return label_index, image
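The bounding-box handling in `parse_example_proto` above can be checked without TensorFlow. The following is a minimal NumPy sketch of the same concat / expand_dims / transpose sequence; the helper name `arrange_bboxes` and the sample coordinates are illustrative assumptions, not part of the repository.

```python
import numpy as np


def arrange_bboxes(xmin, ymin, xmax, ymax):
    """Stack per-box coordinate lists into shape [1, num_boxes, 4].

    Hypothetical helper mirroring the tf.concat / tf.expand_dims /
    tf.transpose calls in the diff; coordinates are ordered
    [ymin, xmin, ymax, xmax] as the docstring there states.
    """
    rows = [
        np.expand_dims(np.asarray(v, dtype=np.float32), 0)
        for v in (ymin, xmin, ymax, xmax)
    ]
    bbox = np.concatenate(rows, 0)  # shape [4, num_boxes]
    bbox = np.expand_dims(bbox, 0)  # shape [1, 4, num_boxes]
    return np.transpose(bbox, [0, 2, 1])  # shape [1, num_boxes, 4]


# Two made-up boxes, purely for illustration.
boxes = arrange_bboxes(
    xmin=[0.1, 0.3], ymin=[0.2, 0.4], xmax=[0.5, 0.7], ymax=[0.6, 0.8]
)
print(boxes.shape)
```

An image with no boxes yields shape `[1, 0, 4]`, which is why the code forces the variable number of boxes into the middle axis.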


@ -1,105 +1,112 @@
# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Code derived from
# tensorflow/tensorflow/models/image/imagenet/classify_image.py
from __future__ import absolute_import, division, print_function
import math
import os.path
import sys
import tarfile
import numpy as np
from six.moves import urllib
import tensorflow as tf
import glob
import scipy.misc
import math
import sys
import horovod.tensorflow as hvd
import numpy as np
import tensorflow as tf
from six.moves import urllib
MODEL_DIR = '/tmp/imagenet'
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
MODEL_DIR = "/tmp/imagenet"
DATA_URL = (
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
)
softmax = None
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())
sess = tf.Session(config=config)
# Call this function with a list of images. Each element should be a
# Call this function with a list of images. Each element should be a
# numpy array with values ranging from 0 to 255.
def get_inception_score(images, splits=10):
# For convenience
if len(images[0].shape) != 3:
return 0, 0
# For convenience
if len(images[0].shape) != 3:
return 0, 0
# Bypassing all the assertions so that we don't end prematurely
# assert(type(images) == list)
# assert(type(images[0]) == np.ndarray)
# assert(len(images[0].shape) == 3)
# assert(np.max(images[0]) > 10)
# assert(np.min(images[0]) >= 0.0)
inps = []
for img in images:
img = img.astype(np.float32)
inps.append(np.expand_dims(img, 0))
bs = 1
preds = []
n_batches = int(math.ceil(float(len(inps)) / float(bs)))
for i in range(n_batches):
sys.stdout.write(".")
sys.stdout.flush()
inp = inps[(i * bs) : min((i + 1) * bs, len(inps))]
inp = np.concatenate(inp, 0)
pred = sess.run(softmax, {"ExpandDims:0": inp})
preds.append(pred)
preds = np.concatenate(preds, 0)
scores = []
for i in range(splits):
part = preds[
(i * preds.shape[0] // splits) : ((i + 1) * preds.shape[0] // splits), :
]
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
kl = np.mean(np.sum(kl, 1))
scores.append(np.exp(kl))
return np.mean(scores), np.std(scores)
# Bypassing all the assertions so that we don't end prematurely
# assert(type(images) == list)
# assert(type(images[0]) == np.ndarray)
# assert(len(images[0].shape) == 3)
# assert(np.max(images[0]) > 10)
# assert(np.min(images[0]) >= 0.0)
inps = []
for img in images:
img = img.astype(np.float32)
inps.append(np.expand_dims(img, 0))
bs = 1
preds = []
n_batches = int(math.ceil(float(len(inps)) / float(bs)))
for i in range(n_batches):
sys.stdout.write(".")
sys.stdout.flush()
inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
inp = np.concatenate(inp, 0)
pred = sess.run(softmax, {'ExpandDims:0': inp})
preds.append(pred)
preds = np.concatenate(preds, 0)
scores = []
for i in range(splits):
part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
kl = np.mean(np.sum(kl, 1))
scores.append(np.exp(kl))
return np.mean(scores), np.std(scores)
# This function is called automatically.
def _init_inception():
global softmax
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(MODEL_DIR, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
with tf.gfile.FastGFile(os.path.join(
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
# Works with an arbitrary minibatch size.
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.set_shape(tf.TensorShape(new_shape))
w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
softmax = tf.nn.softmax(logits)
global softmax
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
filename = DATA_URL.split("/")[-1]
filepath = os.path.join(MODEL_DIR, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write(
"\r>> Downloading %s %.1f%%"
% (filename, float(count * block_size) / float(total_size) * 100.0)
)
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print("Successfully downloaded", filename, statinfo.st_size, "bytes.")
tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
with tf.gfile.FastGFile(
os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
) as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name="")
# Works with an arbitrary minibatch size.
pool3 = sess.graph.get_tensor_by_name("pool_3:0")
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.set_shape(tf.TensorShape(new_shape))
w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
softmax = tf.nn.softmax(logits)
if softmax is None:
_init_inception()
_init_inception()
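The scoring loop in `get_inception_score` above takes softmax predictions, splits them into groups, and exponentiates the mean KL divergence between each prediction and the group's marginal. Below is a minimal NumPy sketch of just that arithmetic, assuming the predictions are already available as an array. The function name `inception_score_from_preds` and the `eps` guard against `log(0)` are assumptions added for illustration; the code in the diff takes the logarithm directly.

```python
import numpy as np


def inception_score_from_preds(preds, splits=10, eps=1e-12):
    """Compute (mean, std) of the inception score over `splits` groups.

    `preds` is an [N, num_classes] array of softmax outputs. `eps` is an
    assumption of this sketch, added so that zero probabilities do not
    produce NaN; it is not in the original code.
    """
    preds = np.asarray(preds, dtype=np.float64)
    n = preds.shape[0]
    scores = []
    for i in range(splits):
        part = preds[(i * n // splits):((i + 1) * n // splits), :]
        marginal = np.mean(part, 0, keepdims=True)
        # Per-sample KL(p(y|x) || p(y)), summed over classes.
        kl = part * (np.log(part + eps) - np.log(marginal + eps))
        scores.append(np.exp(np.mean(np.sum(kl, 1))))
    return float(np.mean(scores)), float(np.std(scores))


# Uninformative predictions score 1; confident, evenly spread
# predictions over 10 classes score close to 10.
uniform_mean, uniform_std = inception_score_from_preds(np.full((100, 10), 0.1))
onehot_mean, onehot_std = inception_score_from_preds(np.tile(np.eye(10), (10, 1)))
```

The two toy inputs bracket the range of the metric for 10 classes, which makes them a convenient sanity check before running the full Inception graph.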

File diff suppressed because it is too large

@ -1,53 +1,56 @@
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import flags
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
import os.path as osp
import os
from utils import optimistic_restore, remap_restore, optimistic_remap_restore
from tqdm import tqdm
import random
from scipy.misc import imsave
from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, TFImagenetLoader
from torch.utils.data import DataLoader
from baselines.common.tf_util import initialize
import horovod.tensorflow as hvd
import numpy as np
import tensorflow as tf
from baselines.common.tf_util import initialize
from data import Cifar10, Imagenet, TFImagenetLoader
from fid import get_fid_score
from inception import get_inception_score
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
from scipy.misc import imsave
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import optimistic_remap_restore
hvd.init()
from inception import get_inception_score
from fid import get_fid_score
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_bool('cclass', False, 'whether to condition on class')
flags.DEFINE_string(
"logdir", "cachedir", "location where log of experiments will be stored"
)
flags.DEFINE_string("exp", "default", "name of experiments")
flags.DEFINE_bool("cclass", False, "whether to condition on class")
# Architecture settings
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_float('step_lr', 10.0, 'Size of steps for gradient descent')
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
flags.DEFINE_float('proj_norm', 0.05, 'Maximum change of input images')
flags.DEFINE_integer('batch_size', 512, 'batch size')
flags.DEFINE_integer('resume_iter', -1, 'resume iteration')
flags.DEFINE_integer('ensemble', 10, 'number of ensembles')
flags.DEFINE_integer('im_number', 50000, 'number of images to generate')
flags.DEFINE_integer('repeat_scale', 100, 'number of repeat iterations')
flags.DEFINE_float('noise_scale', 0.005, 'amount of noise to output')
flags.DEFINE_integer('idx', 0, 'save index')
flags.DEFINE_integer('nomix', 10, 'number of intervals to stop mixing')
flags.DEFINE_bool('scaled', True, 'whether to scale noise added')
flags.DEFINE_bool('large_model', False, 'whether to use a small or large model')
flags.DEFINE_bool('larger_model', False, 'Whether to use a large model')
flags.DEFINE_bool('wider_model', False, 'Whether to use a large model')
flags.DEFINE_bool('single', False, 'single ')
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or imagenet or imagenetfull')
flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
flags.DEFINE_float("step_lr", 10.0, "Size of steps for gradient descent")
flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
flags.DEFINE_float("proj_norm", 0.05, "Maximum change of input images")
flags.DEFINE_integer("batch_size", 512, "batch size")
flags.DEFINE_integer("resume_iter", -1, "resume iteration")
flags.DEFINE_integer("ensemble", 10, "number of ensembles")
flags.DEFINE_integer("im_number", 50000, "number of images to generate")
flags.DEFINE_integer("repeat_scale", 100, "number of repeat iterations")
flags.DEFINE_float("noise_scale", 0.005, "amount of noise to output")
flags.DEFINE_integer("idx", 0, "save index")
flags.DEFINE_integer("nomix", 10, "number of intervals to stop mixing")
flags.DEFINE_bool("scaled", True, "whether to scale noise added")
flags.DEFINE_bool("large_model", False, "whether to use a small or large model")
flags.DEFINE_bool("larger_model", False, "Whether to use a large model")
flags.DEFINE_bool("wider_model", False, "Whether to use a large model")
flags.DEFINE_bool("single", False, "single ")
flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
flags.DEFINE_string("dataset", "cifar10", "cifar10 or imagenet or imagenetfull")
FLAGS = flags.FLAGS
class InceptionReplayBuffer(object):
def __init__(self, size):
"""Create Replay buffer.
@ -72,14 +75,16 @@ class InceptionReplayBuffer(object):
self._label_storage.extend(list(labels))
else:
if batch_size + self._next_idx < self._maxsize:
self._storage[self._next_idx:self._next_idx+batch_size] = list(ims)
self._label_storage[self._next_idx:self._next_idx+batch_size] = list(labels)
self._storage[self._next_idx : self._next_idx + batch_size] = list(ims)
self._label_storage[self._next_idx : self._next_idx + batch_size] = (
list(labels)
)
else:
split_idx = self._maxsize - self._next_idx
self._storage[self._next_idx:] = list(ims)[:split_idx]
self._storage[:batch_size-split_idx] = list(ims)[split_idx:]
self._label_storage[self._next_idx:] = list(labels)[:split_idx]
self._label_storage[:batch_size-split_idx] = list(labels)[split_idx:]
self._storage[self._next_idx :] = list(ims)[:split_idx]
self._storage[: batch_size - split_idx] = list(ims)[split_idx:]
self._label_storage[self._next_idx :] = list(labels)[:split_idx]
self._label_storage[: batch_size - split_idx] = list(labels)[split_idx:]
self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize
@ -123,12 +128,13 @@ class InceptionReplayBuffer(object):
def rescale_im(im):
return np.clip(im * 256, 0, 255).astype(np.uint8)
def compute_inception(sess, target_vars):
X_START = target_vars['X_START']
Y_GT = target_vars['Y_GT']
X_finals = target_vars['X_finals']
NOISE_SCALE = target_vars['NOISE_SCALE']
energy_noise = target_vars['energy_noise']
X_START = target_vars["X_START"]
Y_GT = target_vars["Y_GT"]
X_finals = target_vars["X_finals"]
NOISE_SCALE = target_vars["NOISE_SCALE"]
energy_noise = target_vars["energy_noise"]
size = FLAGS.im_number
num_steps = size // 1000
@ -136,16 +142,21 @@ def compute_inception(sess, target_vars):
images = []
test_ims = []
if FLAGS.dataset == "cifar10":
test_dataset = Cifar10(full=True, noise=False)
elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
test_dataset = Imagenet(train=False)
if FLAGS.dataset != "imagenetfull":
test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False)
test_dataloader = DataLoader(
test_dataset,
batch_size=FLAGS.batch_size,
num_workers=4,
shuffle=True,
drop_last=False,
)
else:
test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)
test_dataloader = TFImagenetLoader("test", FLAGS.batch_size, 0, 1)
for data_corrupt, data, label_gt in tqdm(test_dataloader):
data = data.numpy()
@ -155,7 +166,6 @@ def compute_inception(sess, target_vars):
test_ims = test_ims[:60000]
break
# n = min(len(images), len(test_ims))
print(len(test_ims))
# fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
@ -187,12 +197,14 @@ def compute_inception(sess, target_vars):
x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
label = np.random.randint(0, classes, (FLAGS.batch_size))
label = identity[label]
x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0]
x_new = sess.run(
[x_final], {X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale}
)[0]
data_buffer.add(x_new, label)
else:
(x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99)
label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9)
keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99
label_keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9
label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
label_corrupt = identity[label_corrupt]
x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
@ -203,7 +215,10 @@ def compute_inception(sess, target_vars):
# else:
# noise_scale = [0.7]
x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})
x_new, e_noise = sess.run(
[x_final, energy_noise],
{X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale},
)
data_buffer.set_elms(idx, x_new, label)
if FLAGS.im_number != 50000:
@ -216,14 +231,22 @@ def compute_inception(sess, target_vars):
images.extend(list(ims))
saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx))
saveim = osp.join("sandbox_cachedir", FLAGS.exp, "test{}.png".format(FLAGS.idx))
ims = ims[:100]
if FLAGS.dataset != "imagenetfull":
im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3))
im_panel = (
ims.reshape((10, 10, 32, 32, 3))
.transpose((0, 2, 1, 3, 4))
.reshape((320, 320, 3))
)
else:
im_panel = ims.reshape((10, 10, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((1280, 1280, 3))
im_panel = (
ims.reshape((10, 10, 128, 128, 3))
.transpose((0, 2, 1, 3, 4))
.reshape((1280, 1280, 3))
)
imsave(saveim, im_panel)
print("Saved image!!!!")
@ -237,8 +260,6 @@ def compute_inception(sess, target_vars):
print("FID of score {}".format(fid))
def main(model_list):
if FLAGS.dataset == "imagenetfull":
@ -259,45 +280,55 @@ def main(model_list):
weights = []
for i, model_num in enumerate(model_list):
weight = model.construct_weights('context_{}'.format(i))
weight = model.construct_weights("context_{}".format(i))
initialize()
save_file = osp.join(logdir, 'model_{}'.format(model_num))
save_file = osp.join(logdir, "model_{}".format(model_num))
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(i))
v_map = {(v.name.replace('context_{}'.format(i), 'context_0')[:-2]): v for v in v_list}
v_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(i)
)
v_map = {
(v.name.replace("context_{}".format(i), "context_0")[:-2]): v
for v in v_list
}
saver = tf.train.Saver(v_map)
try:
saver.restore(sess, save_file)
except:
except BaseException:
optimistic_remap_restore(sess, save_file, i)
weights.append(weight)
if FLAGS.dataset == "imagenetfull":
X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype = tf.float32)
X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
else:
X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
if FLAGS.dataset == "cifar10":
Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32)
Y_GT = tf.placeholder(shape=(None, 10), dtype=tf.float32)
else:
Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
Y_GT = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32)
X_finals = []
# Separate loops
for weight in weights:
X = X_START
steps = tf.constant(0)
c = lambda i, x: tf.less(i, FLAGS.num_steps)
def c(i, x):
return tf.less(i, FLAGS.num_steps)
def langevin_step(counter, X):
scale_rate = 1
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE)
X = X + tf.random_normal(
tf.shape(X),
mean=0.0,
stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE,
)
energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
x_grad = tf.gradients(energy_noise, [X])[0]
@ -305,7 +336,7 @@ def main(model_list):
if FLAGS.proj_norm != 0.0:
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
X = X - FLAGS.step_lr * x_grad * scale_rate
X = X - FLAGS.step_lr * x_grad * scale_rate
X = tf.clip_by_value(X, 0, 1)
counter = counter + 1
@ -318,16 +349,16 @@ def main(model_list):
X_finals.append(X_final)
target_vars = {}
target_vars['X_START'] = X_START
target_vars['Y_GT'] = Y_GT
target_vars['X_finals'] = X_finals
target_vars['NOISE_SCALE'] = NOISE_SCALE
target_vars['energy_noise'] = energy_noise
target_vars["X_START"] = X_START
target_vars["Y_GT"] = Y_GT
target_vars["X_finals"] = X_finals
target_vars["NOISE_SCALE"] = NOISE_SCALE
target_vars["energy_noise"] = energy_noise
compute_inception(sess, target_vars)
if __name__ == "__main__":
# model_list = [117000, 116700]
model_list = [FLAGS.resume_iter - 300*i for i in range(FLAGS.ensemble)]
model_list = [FLAGS.resume_iter - 300 * i for i in range(FLAGS.ensemble)]
main(model_list)
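The `langevin_step` body in the diff above follows a fixed pattern: add Gaussian noise, take the energy gradient, clip it to `proj_norm`, step against it, and clip the sample back to `[0, 1]`. The sketch below reproduces that pattern in NumPy under loud assumptions: the energy model is replaced by a toy quadratic (`toy_energy_grad`), the `NOISE_SCALE` placeholder is folded into `noise_std`, and the function names are hypothetical. It is meant to show the update rule, not the repository's sampler.

```python
import numpy as np


def langevin_sample(x, energy_grad, num_steps=20, step_lr=10.0,
                    noise_std=0.005, proj_norm=0.05, rng=None):
    """Run noisy gradient descent on an energy, keeping x in [0, 1].

    Defaults mirror the flag defaults in the diff (num_steps, step_lr,
    noise_scale, proj_norm). `energy_grad` is a stand-in for
    tf.gradients(model.forward(...), [X]) and is an assumption here.
    """
    rng = np.random.default_rng(0) if rng is None else rng
    for _ in range(num_steps):
        x = x + rng.normal(0.0, noise_std, size=x.shape)
        grad = energy_grad(x)
        if proj_norm != 0.0:
            grad = np.clip(grad, -proj_norm, proj_norm)
        x = x - step_lr * grad
        x = np.clip(x, 0.0, 1.0)
    return x


def toy_energy_grad(x):
    """Gradient of the toy energy E(x) = 0.5 * ||x - 0.5||^2."""
    return x - 0.5


# Noise-free run with a small step size so the toy problem converges.
x0 = np.random.default_rng(1).uniform(0, 1, size=(4, 8, 8, 3))
x_final = langevin_sample(
    x0, toy_energy_grad, num_steps=200, step_lr=0.5, noise_std=0.0
)
x_noisy = langevin_sample(x0, toy_energy_grad)
```

The toy run uses `step_lr=0.5` rather than the flag default of 10.0, because the default is tuned for the trained network's gradients and would oscillate on this quadratic.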

File diff suppressed because it is too large

File diff suppressed because it is too large