mirror of
https://github.com/autistic-symposium/ml-ai-agents-py.git
synced 2025-08-24 05:49:24 -04:00
chores: refactor for the new ai research, add linter, gh action, etc (#27)
This commit is contained in:
parent
fb4ab80dc3
commit
d5467e559f
40 changed files with 5177 additions and 2476 deletions
106
EBMs/README.md
106
EBMs/README.md
|
@ -1,12 +1,15 @@
|
|||
## quantum ai: training energy-based-models using openai
|
||||
## quantum ai: training energy-based-models using openAI
|
||||
|
||||
<br>
|
||||
|
||||
|
||||
#### ⚛️ this repository contains my adapted code from [opeani's implicit generation and generalization in energy-based-models](https://arxiv.org/pdf/1903.08689.pdf)
|
||||
|
||||
#### ⚛️ this repository contains my adapted code from [opeani's implicit generation and generalization in
|
||||
energy-based-models](https://arxiv.org/pdf/1903.08689.pdf)
|
||||
|
||||
<br>
|
||||
|
||||
---
|
||||
|
||||
### installing
|
||||
|
||||
<br>
|
||||
|
@ -19,7 +22,8 @@ brew install pkg-config
|
|||
|
||||
<br>
|
||||
|
||||
* there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi for the specific libraries in this problem (`PMIX ERROR: ERROR`) that can be fixed with:
|
||||
* there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi for the specific libraries in this
|
||||
problem (`PMIX ERROR: ERROR`) that can be fixed with:
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -40,7 +44,8 @@ pip install -r requirements.txt
|
|||
```
|
||||
<br>
|
||||
|
||||
* note that this is an adapted requirement file since the **[openai's original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is not complete/correct
|
||||
* note that this is an adapted requirement file since the **[openai's
|
||||
original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is not complete/correct
|
||||
* finally, download and install **[mujoco](https://www.roboti.us/index.html)**
|
||||
* you will also need to register for a license, which asks for a machine ID
|
||||
* the documentation on the website is incomplete, so just download the suggested script and run:
|
||||
|
@ -64,7 +69,8 @@ mv getid_osx getid_osx.dms
|
|||
|
||||
<br>
|
||||
|
||||
* download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder `cachedir`:
|
||||
* download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder
|
||||
`cachedir`:
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -78,7 +84,8 @@ mkdir cachedir
|
|||
|
||||
<br>
|
||||
|
||||
* openai's original code contains **[hardcoded constants that only work on Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
|
||||
* openai's original code contains **[hardcoded constants that only work on
|
||||
Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
|
||||
* i changed this to a constant (`ROOT_DIR = "./results"`) in the top of `data.py`
|
||||
|
||||
<br>
|
||||
|
@ -87,7 +94,8 @@ mkdir cachedir
|
|||
|
||||
<br>
|
||||
|
||||
* all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so model training can be increased substantially by using multiple different workers by running each command:
|
||||
* all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so model training can be increased
|
||||
substantially by using multiple different workers by running each command:
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -102,7 +110,8 @@ mpiexec -n <worker_num> <command>
|
|||
<br>
|
||||
|
||||
```
|
||||
python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --large_model
|
||||
python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01
|
||||
--zero_kl --replay_batch --large_model
|
||||
```
|
||||
|
||||
* this should generate the following output:
|
||||
|
@ -112,7 +121,8 @@ python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_si
|
|||
```bash
|
||||
Instructions for updating:
|
||||
Use tf.gfile.GFile.
|
||||
2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
|
||||
2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is
|
||||
deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
|
||||
64 batch size
|
||||
Local rank: 0 1
|
||||
Loading data...
|
||||
|
@ -121,11 +131,15 @@ Files already downloaded and verified
|
|||
Files already downloaded and verified
|
||||
Files already downloaded and verified
|
||||
Done loading...
|
||||
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
|
||||
WARNING:tensorflow:From
|
||||
/Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263:
|
||||
colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
|
||||
Instructions for updating:
|
||||
Colocations handled automatically by placer.
|
||||
Building graph...
|
||||
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
|
||||
WARNING:tensorflow:From
|
||||
/Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from
|
||||
tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
|
||||
Instructions for updating:
|
||||
Use tf.cast instead.
|
||||
Finished processing loop construction ...
|
||||
|
@ -136,16 +150,36 @@ Model has a total of 7567880 parameters
|
|||
Initializing variables...
|
||||
Start broadcast
|
||||
End broadcast
|
||||
Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff: 0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent: 2.272536277770996, l
|
||||
oss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first: 0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818, context_0/res_optim_res_c1/
|
||||
cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10, context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10, context_0/res_optim_res_c2/gb:0: 6.27235652306268
|
||||
3e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11, context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437, context_0/res_1_res_c2/g:0
|
||||
: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0: 1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0: 6.868505764145993e-10, context_0/res_2
|
||||
_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0: 8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0: 0.0194275379180908
|
||||
2, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10, context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11, context_0/res_3_res_c2/gb:0: 7.75612463
|
||||
1441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0: 4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0: 0.34806951880455017, context_0/res_4_res_c2/g:
|
||||
0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0: 5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0: 3.807749116013781e-10, context_0/res_5
|
||||
_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0: 7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039, context_0/fc5/
|
||||
Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff:
|
||||
0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent:
|
||||
2.272536277770996, l
|
||||
oss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first:
|
||||
0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818,
|
||||
context_0/res_optim_res_c1/
|
||||
cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10,
|
||||
context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10,
|
||||
context_0/res_optim_res_c2/gb:0: 6.27235652306268
|
||||
3e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11,
|
||||
context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437,
|
||||
context_0/res_1_res_c2/g:0
|
||||
: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0:
|
||||
1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0:
|
||||
6.868505764145993e-10, context_0/res_2
|
||||
_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0:
|
||||
8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0:
|
||||
0.0194275379180908
|
||||
2, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10,
|
||||
context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11,
|
||||
context_0/res_3_res_c2/gb:0: 7.75612463
|
||||
1441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0:
|
||||
4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0:
|
||||
0.34806951880455017, context_0/res_4_res_c2/g:
|
||||
0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0:
|
||||
5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0:
|
||||
3.807749116013781e-10, context_0/res_5
|
||||
_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0:
|
||||
7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039,
|
||||
context_0/fc5/
|
||||
b:0: 0.4506262540817261,
|
||||
|
||||
................................................................................................................................
|
||||
|
@ -159,7 +193,8 @@ Inception score of 1.2397289276123047 with std of 0.0
|
|||
<br>
|
||||
|
||||
```
|
||||
python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --cclass
|
||||
python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01
|
||||
--zero_kl --replay_batch --cclass
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -169,7 +204,8 @@ python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size
|
|||
<br>
|
||||
|
||||
```
|
||||
python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 step_lr=10.0 --proj_norm=0.01 --replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=<imagenet32x32 path>
|
||||
python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 step_lr=10.0 --proj_norm=0.01
|
||||
--replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=<imagenet32x32 path>
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -179,7 +215,8 @@ python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=3
|
|||
<br>
|
||||
|
||||
```
|
||||
python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0 --replay_batch --swish_act --cclass --zero_kl --dataset=imagenetfull --imagenet_datadir=<full imagenet path>
|
||||
python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0 --replay_batch --swish_act --cclass
|
||||
--zero_kl --dataset=imagenetfull --imagenet_datadir=<full imagenet path>
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -197,7 +234,8 @@ python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0
|
|||
python imagenet_demo.py --exp=imagenet128_cond --resume_iter=2238000 --swish_act
|
||||
```
|
||||
|
||||
* the ebm_sandbox.py file contains several different tasks that can be used to evaluate ebms, which are defined by different settings of task flag in the file
|
||||
* the ebm_sandbox.py file contains several different tasks that can be used to evaluate ebms, which are defined by
|
||||
different settings of task flag in the file
|
||||
* for example, to visualize cross class mappings in cifar-10, you can run:
|
||||
|
||||
<br>
|
||||
|
@ -217,7 +255,8 @@ python ebm_sandbox.py --task=crossclass --num_steps=40 --exp=cifar10_cond --resu
|
|||
<br>
|
||||
|
||||
```
|
||||
python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200 --large_model --svhnmix --cclass=False
|
||||
python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200
|
||||
--large_model --svhnmix --cclass=False
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -227,7 +266,8 @@ python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_
|
|||
<br>
|
||||
|
||||
```
|
||||
python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd=<number of pgd steps> --num_steps=10 --lival=<li bound value> --wider_model
|
||||
python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd=<number of pgd
|
||||
steps> --num_steps=10 --lival=<li bound value> --wider_model
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -236,12 +276,14 @@ python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=
|
|||
|
||||
<br>
|
||||
|
||||
* to train ebms on conditional dsprites dataset, you can train each model separately on each conditioned latent in `cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`, with an example command given below:
|
||||
* to train ebms on conditional dsprites dataset, you can train each model separately on each conditioned latent in
|
||||
`cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`, with an example command given below:
|
||||
|
||||
<br>
|
||||
|
||||
```
|
||||
python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act --cond_pos --replay_batch -cclass
|
||||
python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act
|
||||
--cond_pos --replay_batch -cclass
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -249,5 +291,7 @@ python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps
|
|||
* once models are trained, they can be sampled from jointly by running:
|
||||
|
||||
```
|
||||
python ebm_combine.py --task=conceptcombine --exp_size=<exp_size> --exp_shape=<exp_shape> --exp_pos=<exp_pos> --exp_rot=<exp_rot> --resume_size=<resume_size> --resume_shape=<resume_shape> --resume_rot=<resume_rot> --resume_pos=<resume_pos>
|
||||
python ebm_combine.py --task=conceptcombine --exp_size=<exp_size> --exp_shape=<exp_shape> --exp_pos=<exp_pos>
|
||||
--exp_rot=<exp_rot> --resume_size=<resume_size> --resume_shape=<resume_shape> --resume_rot=<resume_rot>
|
||||
--resume_pos=<resume_pos>
|
||||
```
|
||||
|
|
228
EBMs/ais.py
228
EBMs/ais.py
|
@ -1,40 +1,65 @@
|
|||
import tensorflow as tf
|
||||
import math
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from data import Cifar10, DSprites, Mnist
|
||||
from hmc import hmc
|
||||
from models import DspritesNet, MnistNet, ResNet32, ResNet32Large, ResNet32Wider
|
||||
from tensorflow.python.platform import flags
|
||||
from torch.utils.data import DataLoader
|
||||
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
|
||||
from data import Cifar10, Mnist, DSprites
|
||||
from scipy.misc import logsumexp
|
||||
from scipy.misc import imsave
|
||||
from utils import optimistic_restore
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from utils import optimistic_restore
|
||||
|
||||
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss')
|
||||
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_string('exp', 'default', 'name of experiments')
|
||||
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
|
||||
flags.DEFINE_integer('batch_size', 16, 'Size of inputs')
|
||||
flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from')
|
||||
flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
|
||||
flags.DEFINE_string(
|
||||
"dataset", "cifar10", "cifar10 or mnist or dsprites or 2d or toy Gauss"
|
||||
)
|
||||
flags.DEFINE_string(
|
||||
"logdir", "cachedir", "location where log of experiments will be stored"
|
||||
)
|
||||
flags.DEFINE_string("exp", "default", "name of experiments")
|
||||
flags.DEFINE_integer(
|
||||
"data_workers", 5, "Number of different data workers to load data in parallel"
|
||||
)
|
||||
flags.DEFINE_integer("batch_size", 16, "Size of inputs")
|
||||
flags.DEFINE_string("resume_iter", "-1", "iteration to resume training from")
|
||||
|
||||
flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
|
||||
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
|
||||
flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais')
|
||||
flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian')
|
||||
flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box')
|
||||
flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model')
|
||||
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
|
||||
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
|
||||
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
|
||||
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
|
||||
flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not')
|
||||
flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not')
|
||||
flags.DEFINE_bool('large_model', False, 'Use large model to evaluate')
|
||||
flags.DEFINE_bool('wider_model', False, 'Use large model to evaluate')
|
||||
flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps')
|
||||
flags.DEFINE_bool(
|
||||
"max_pool",
|
||||
False,
|
||||
"Whether or not to use max pooling rather than strided convolutions",
|
||||
)
|
||||
flags.DEFINE_integer(
|
||||
"num_filters",
|
||||
64,
|
||||
"number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
|
||||
)
|
||||
flags.DEFINE_integer("pdist", 10, "number of intermediate distributions for ais")
|
||||
flags.DEFINE_integer("gauss_dim", 500, "dimensions for modeling Gaussian")
|
||||
flags.DEFINE_integer(
|
||||
"rescale", 1, "factor to rescale input outside of normal (0, 1) box"
|
||||
)
|
||||
flags.DEFINE_float(
|
||||
"temperature", 1, "temperature at which to compute likelihood of model"
|
||||
)
|
||||
flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
|
||||
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
|
||||
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
|
||||
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
|
||||
flags.DEFINE_bool(
|
||||
"cclass",
|
||||
False,
|
||||
"Whether to evaluate the log likelihood of conditional model or not",
|
||||
)
|
||||
flags.DEFINE_bool(
|
||||
"single",
|
||||
False,
|
||||
"Whether to evaluate the log likelihood of conditional model or not",
|
||||
)
|
||||
flags.DEFINE_bool("large_model", False, "Use large model to evaluate")
|
||||
flags.DEFINE_bool("wider_model", False, "Use large model to evaluate")
|
||||
flags.DEFINE_float("alr", 0.0045, "Learning rate to use for HMC steps")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
@ -45,11 +70,12 @@ label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))
|
|||
def unscale_im(im):
|
||||
return (255 * np.clip(im, 0, 1)).astype(np.uint8)
|
||||
|
||||
|
||||
def gauss_prob_log(x, prec=1.0):
|
||||
|
||||
nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
|
||||
norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
|
||||
prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec
|
||||
prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2.0 * prec
|
||||
|
||||
return norm_constant_log + prob_density_log
|
||||
|
||||
|
@ -73,23 +99,36 @@ def model_prob_log(x, e_func, weights, temp):
|
|||
def bridge_prob_neg_log(alpha, x, e_func, weights, temp):
|
||||
|
||||
if FLAGS.dataset == "gauss":
|
||||
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature)
|
||||
norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(
|
||||
x, prec=FLAGS.temperature
|
||||
)
|
||||
else:
|
||||
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp)
|
||||
# Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely
|
||||
norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * model_prob_log(
|
||||
x, e_func, weights, temp
|
||||
)
|
||||
# Add an additional log likelihood penalty so that points outside of (0,
|
||||
# 1) box are *highly* unlikely
|
||||
|
||||
|
||||
if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
|
||||
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1])
|
||||
elif FLAGS.dataset == 'mnist':
|
||||
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1, 2])
|
||||
if FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
|
||||
oob_prob = tf.reduce_sum(
|
||||
tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1]
|
||||
)
|
||||
elif FLAGS.dataset == "mnist":
|
||||
oob_prob = tf.reduce_sum(
|
||||
tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1, 2]
|
||||
)
|
||||
else:
|
||||
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis = [1, 2, 3])
|
||||
oob_prob = tf.reduce_sum(
|
||||
tf.square(100 * (x - tf.clip_by_value(x, 0.0, FLAGS.rescale))),
|
||||
axis=[1, 2, 3],
|
||||
)
|
||||
|
||||
return -norm_prob + oob_prob
|
||||
|
||||
|
||||
def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10):
|
||||
def ancestral_sample(
|
||||
e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10
|
||||
):
|
||||
if FLAGS.dataset == "2d":
|
||||
x = tf.placeholder(tf.float32, shape=(None, 2))
|
||||
elif FLAGS.dataset == "gauss":
|
||||
|
@ -130,41 +169,46 @@ def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_
|
|||
def main():
|
||||
|
||||
# Initialize dataset
|
||||
if FLAGS.dataset == 'cifar10':
|
||||
if FLAGS.dataset == "cifar10":
|
||||
dataset = Cifar10(train=False, rescale=FLAGS.rescale)
|
||||
channel_num = 3
|
||||
dim_input = 32 * 32 * 3
|
||||
elif FLAGS.dataset == 'imagenet':
|
||||
32 * 32 * 3
|
||||
elif FLAGS.dataset == "imagenet":
|
||||
dataset = ImagenetClass()
|
||||
channel_num = 3
|
||||
dim_input = 64 * 64 * 3
|
||||
elif FLAGS.dataset == 'mnist':
|
||||
64 * 64 * 3
|
||||
elif FLAGS.dataset == "mnist":
|
||||
dataset = Mnist(train=False, rescale=FLAGS.rescale)
|
||||
channel_num = 1
|
||||
dim_input = 28 * 28 * 1
|
||||
elif FLAGS.dataset == 'dsprites':
|
||||
28 * 28 * 1
|
||||
elif FLAGS.dataset == "dsprites":
|
||||
dataset = DSprites()
|
||||
channel_num = 1
|
||||
dim_input = 64 * 64 * 1
|
||||
elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
|
||||
64 * 64 * 1
|
||||
elif FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
|
||||
dataset = Box2D()
|
||||
|
||||
dim_output = 1
|
||||
data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)
|
||||
data_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_workers=FLAGS.data_workers,
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
if FLAGS.dataset == 'mnist':
|
||||
if FLAGS.dataset == "mnist":
|
||||
model = MnistNet(num_channels=channel_num)
|
||||
elif FLAGS.dataset == 'cifar10':
|
||||
elif FLAGS.dataset == "cifar10":
|
||||
if FLAGS.large_model:
|
||||
model = ResNet32Large(num_filters=128)
|
||||
elif FLAGS.wider_model:
|
||||
model = ResNet32Wider(num_filters=192)
|
||||
else:
|
||||
model = ResNet32(num_channels=channel_num, num_filters=128)
|
||||
elif FLAGS.dataset == 'dsprites':
|
||||
elif FLAGS.dataset == "dsprites":
|
||||
model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)
|
||||
|
||||
weights = model.construct_weights('context_{}'.format(0))
|
||||
weights = model.construct_weights("context_{}".format(0))
|
||||
|
||||
config = tf.ConfigProto()
|
||||
sess = tf.Session(config=config)
|
||||
|
@ -173,8 +217,8 @@ def main():
|
|||
sess.run(tf.global_variables_initializer())
|
||||
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
|
||||
|
||||
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
|
||||
resume_itr = FLAGS.resume_iter
|
||||
model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
|
||||
FLAGS.resume_iter
|
||||
|
||||
if FLAGS.resume_iter != "-1":
|
||||
optimistic_restore(sess, model_file)
|
||||
|
@ -182,14 +226,17 @@ def main():
|
|||
print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
|
||||
# saver.restore(sess, model_file)
|
||||
|
||||
chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
|
||||
chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
|
||||
model, weights, FLAGS.batch_size, temp=FLAGS.temperature
|
||||
)
|
||||
print("Finished constructing ancestral sample ...................")
|
||||
|
||||
if FLAGS.dataset != "gauss":
|
||||
comb_weights_cum = []
|
||||
batch_size = tf.shape(x_init)[0]
|
||||
label_tiled = tf.tile(label_default, (batch_size, 1))
|
||||
e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
|
||||
e_compute = -FLAGS.temperature * model.forward(
|
||||
x_init, weights, label=label_tiled
|
||||
)
|
||||
e_pos_list = []
|
||||
|
||||
for data_corrupt, data, label_gt in tqdm(data_loader):
|
||||
|
@ -205,44 +252,75 @@ def main():
|
|||
alr = 0.0085
|
||||
elif FLAGS.dataset == "mnist":
|
||||
alr = 0.0065
|
||||
#90 alr = 0.0035
|
||||
# 90 alr = 0.0035
|
||||
else:
|
||||
# alr = 0.0125
|
||||
if FLAGS.rescale == 8:
|
||||
alr = 0.0085
|
||||
else:
|
||||
alr = 0.0045
|
||||
#
|
||||
#
|
||||
for i in range(1):
|
||||
tot_weight = 0
|
||||
for j in tqdm(range(1, FLAGS.pdist+1)):
|
||||
for j in tqdm(range(1, FLAGS.pdist + 1)):
|
||||
if j == 1:
|
||||
if FLAGS.dataset == "cifar10":
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
|
||||
x_curr = np.random.uniform(
|
||||
0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3)
|
||||
)
|
||||
elif FLAGS.dataset == "gauss":
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
|
||||
x_curr = np.random.uniform(
|
||||
0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim)
|
||||
)
|
||||
elif FLAGS.dataset == "mnist":
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
|
||||
x_curr = np.random.uniform(
|
||||
0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28)
|
||||
)
|
||||
else:
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))
|
||||
x_curr = np.random.uniform(
|
||||
0, FLAGS.rescale, size=(FLAGS.batch_size, 2)
|
||||
)
|
||||
|
||||
alpha_prev = (j-1) / FLAGS.pdist
|
||||
alpha_prev = (j - 1) / FLAGS.pdist
|
||||
alpha_new = j / FLAGS.pdist
|
||||
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
|
||||
cweight, x_curr = sess.run(
|
||||
[chain_weights, x],
|
||||
{
|
||||
a_prev: alpha_prev,
|
||||
a_new: alpha_new,
|
||||
x_init: x_curr,
|
||||
approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
|
||||
},
|
||||
)
|
||||
tot_weight = tot_weight + cweight
|
||||
|
||||
print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))
|
||||
print(
|
||||
"Total values of lower value based off forward sampling",
|
||||
np.mean(tot_weight),
|
||||
np.std(tot_weight),
|
||||
)
|
||||
|
||||
tot_weight = 0
|
||||
|
||||
for j in tqdm(range(FLAGS.pdist, 0, -1)):
|
||||
alpha_new = (j-1) / FLAGS.pdist
|
||||
alpha_new = (j - 1) / FLAGS.pdist
|
||||
alpha_prev = j / FLAGS.pdist
|
||||
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
|
||||
cweight, x_curr = sess.run(
|
||||
[chain_weights, x],
|
||||
{
|
||||
a_prev: alpha_prev,
|
||||
a_new: alpha_new,
|
||||
x_init: x_curr,
|
||||
approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
|
||||
},
|
||||
)
|
||||
tot_weight = tot_weight - cweight
|
||||
|
||||
print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))
|
||||
|
||||
print(
|
||||
"Total values of upper value based off backward sampling",
|
||||
np.mean(tot_weight),
|
||||
np.std(tot_weight),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -14,223 +14,244 @@
|
|||
# ==============================================================================
|
||||
|
||||
"""Adam for TensorFlow."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.training import optimizer
|
||||
from tensorflow.python.training import training_ops
|
||||
from tensorflow.python.ops import (
|
||||
control_flow_ops,
|
||||
math_ops,
|
||||
resource_variable_ops,
|
||||
state_ops,
|
||||
)
|
||||
from tensorflow.python.training import optimizer, training_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
@tf_export("train.AdamOptimizer")
|
||||
class AdamOptimizer(optimizer.Optimizer):
|
||||
"""Optimizer that implements the Adam algorithm.
|
||||
"""Optimizer that implements the Adam algorithm.
|
||||
|
||||
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
|
||||
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
|
||||
"""
|
||||
|
||||
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
|
||||
use_locking=False, name="Adam"):
|
||||
"""Construct a new Adam optimizer.
|
||||
|
||||
Initialization:
|
||||
|
||||
$$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
|
||||
$$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
|
||||
$$t := 0 \text{(Initialize timestep)}$$
|
||||
|
||||
The update rule for `variable` with gradient `g` uses an optimization
|
||||
described at the end of section2 of the paper:
|
||||
|
||||
$$t := t + 1$$
|
||||
$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
|
||||
|
||||
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
|
||||
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
|
||||
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
|
||||
|
||||
The default value of 1e-8 for epsilon might not be a good default in
|
||||
general. For example, when training an Inception network on ImageNet a
|
||||
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
|
||||
formulation just before Section 2.1 of the Kingma and Ba paper rather than
|
||||
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
|
||||
hat" in the paper.
|
||||
|
||||
The sparse implementation of this algorithm (used when the gradient is an
|
||||
IndexedSlices object, typically because of `tf.gather` or an embedding
|
||||
lookup in the forward pass) does apply momentum to variable slices even if
|
||||
they were not used in the forward pass (meaning they have a gradient equal
|
||||
to zero). Momentum decay (beta1) is also applied to the entire momentum
|
||||
accumulator. This means that the sparse behavior is equivalent to the dense
|
||||
behavior (in contrast to some momentum implementations which ignore momentum
|
||||
unless a variable slice was actually used).
|
||||
|
||||
Args:
|
||||
learning_rate: A Tensor or a floating point value. The learning rate.
|
||||
beta1: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 1st moment estimates.
|
||||
beta2: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 2nd moment estimates.
|
||||
epsilon: A small constant for numerical stability. This epsilon is
|
||||
"epsilon hat" in the Kingma and Ba paper (in the formula just before
|
||||
Section 2.1), not the epsilon in Algorithm 1 of the paper.
|
||||
use_locking: If True use locks for update operations.
|
||||
name: Optional name for the operations created when applying gradients.
|
||||
Defaults to "Adam".
|
||||
|
||||
@compatibility(eager)
|
||||
When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
|
||||
`epsilon` can each be a callable that takes no arguments and returns the
|
||||
actual value to use. This can be useful for changing these values across
|
||||
different invocations of optimizer functions.
|
||||
@end_compatibility
|
||||
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
|
||||
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
|
||||
"""
|
||||
super(AdamOptimizer, self).__init__(use_locking, name)
|
||||
self._lr = learning_rate
|
||||
self._beta1 = beta1
|
||||
self._beta2 = beta2
|
||||
self._epsilon = epsilon
|
||||
|
||||
# Tensor versions of the constructor arguments, created in _prepare().
|
||||
self._lr_t = None
|
||||
self._beta1_t = None
|
||||
self._beta2_t = None
|
||||
self._epsilon_t = None
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate=0.001,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
epsilon=1e-8,
|
||||
use_locking=False,
|
||||
name="Adam",
|
||||
):
|
||||
"""Construct a new Adam optimizer.
|
||||
|
||||
# Created in SparseApply if needed.
|
||||
self._updated_lr = None
|
||||
Initialization:
|
||||
|
||||
def _get_beta_accumulators(self):
|
||||
with ops.init_scope():
|
||||
if context.executing_eagerly():
|
||||
graph = None
|
||||
else:
|
||||
graph = ops.get_default_graph()
|
||||
return (self._get_non_slot_variable("beta1_power", graph=graph),
|
||||
self._get_non_slot_variable("beta2_power", graph=graph))
|
||||
$$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
|
||||
$$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
|
||||
$$t := 0 \text{(Initialize timestep)}$$
|
||||
|
||||
def _create_slots(self, var_list):
|
||||
# Create the beta1 and beta2 accumulators on the same device as the first
|
||||
# variable. Sort the var_list to make sure this device is consistent across
|
||||
# workers (these need to go on the same PS, otherwise some updates are
|
||||
# silently ignored).
|
||||
first_var = min(var_list, key=lambda x: x.name)
|
||||
self._create_non_slot_variable(initial_value=self._beta1,
|
||||
name="beta1_power",
|
||||
colocate_with=first_var)
|
||||
self._create_non_slot_variable(initial_value=self._beta2,
|
||||
name="beta2_power",
|
||||
colocate_with=first_var)
|
||||
The update rule for `variable` with gradient `g` uses an optimization
|
||||
described at the end of section2 of the paper:
|
||||
|
||||
# Create slots for the first and second moments.
|
||||
for v in var_list:
|
||||
self._zeros_slot(v, "m", self._name)
|
||||
self._zeros_slot(v, "v", self._name)
|
||||
$$t := t + 1$$
|
||||
$$lr_t := \text{learning\\_rate} * \\sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
|
||||
|
||||
def _prepare(self):
|
||||
lr = self._call_if_callable(self._lr)
|
||||
beta1 = self._call_if_callable(self._beta1)
|
||||
beta2 = self._call_if_callable(self._beta2)
|
||||
epsilon = self._call_if_callable(self._epsilon)
|
||||
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
|
||||
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
|
||||
$$variable := variable - lr_t * m_t / (\\sqrt{v_t} + \\epsilon)$$
|
||||
|
||||
self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
|
||||
self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
|
||||
self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
|
||||
self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
|
||||
The default value of 1e-8 for epsilon might not be a good default in
|
||||
general. For example, when training an Inception network on ImageNet a
|
||||
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
|
||||
formulation just before Section 2.1 of the Kingma and Ba paper rather than
|
||||
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
|
||||
hat" in the paper.
|
||||
|
||||
def _apply_dense(self, grad, var):
|
||||
m = self.get_slot(var, "m")
|
||||
v = self.get_slot(var, "v")
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
The sparse implementation of this algorithm (used when the gradient is an
|
||||
IndexedSlices object, typically because of `tf.gather` or an embedding
|
||||
lookup in the forward pass) does apply momentum to variable slices even if
|
||||
they were not used in the forward pass (meaning they have a gradient equal
|
||||
to zero). Momentum decay (beta1) is also applied to the entire momentum
|
||||
accumulator. This means that the sparse behavior is equivalent to the dense
|
||||
behavior (in contrast to some momentum implementations which ignore momentum
|
||||
unless a variable slice was actually used).
|
||||
|
||||
clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
|
||||
grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
|
||||
# Clip gradients by 3 std
|
||||
return training_ops.apply_adam(
|
||||
var, m, v,
|
||||
math_ops.cast(beta1_power, var.dtype.base_dtype),
|
||||
math_ops.cast(beta2_power, var.dtype.base_dtype),
|
||||
math_ops.cast(self._lr_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
|
||||
grad, use_locking=self._use_locking).op
|
||||
Args:
|
||||
learning_rate: A Tensor or a floating point value. The learning rate.
|
||||
beta1: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 1st moment estimates.
|
||||
beta2: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 2nd moment estimates.
|
||||
epsilon: A small constant for numerical stability. This epsilon is
|
||||
"epsilon hat" in the Kingma and Ba paper (in the formula just before
|
||||
Section 2.1), not the epsilon in Algorithm 1 of the paper.
|
||||
use_locking: If True use locks for update operations.
|
||||
name: Optional name for the operations created when applying gradients.
|
||||
Defaults to "Adam".
|
||||
|
||||
def _resource_apply_dense(self, grad, var):
|
||||
m = self.get_slot(var, "m")
|
||||
v = self.get_slot(var, "v")
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
return training_ops.resource_apply_adam(
|
||||
var.handle, m.handle, v.handle,
|
||||
math_ops.cast(beta1_power, grad.dtype.base_dtype),
|
||||
math_ops.cast(beta2_power, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._lr_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
|
||||
grad, use_locking=self._use_locking)
|
||||
@compatibility(eager)
|
||||
When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
|
||||
`epsilon` can each be a callable that takes no arguments and returns the
|
||||
actual value to use. This can be useful for changing these values across
|
||||
different invocations of optimizer functions.
|
||||
@end_compatibility
|
||||
"""
|
||||
super(AdamOptimizer, self).__init__(use_locking, name)
|
||||
self._lr = learning_rate
|
||||
self._beta1 = beta1
|
||||
self._beta2 = beta2
|
||||
self._epsilon = epsilon
|
||||
|
||||
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
|
||||
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
|
||||
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
|
||||
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
|
||||
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
|
||||
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
|
||||
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
|
||||
# m_t = beta1 * m + (1 - beta1) * g_t
|
||||
m = self.get_slot(var, "m")
|
||||
m_scaled_g_values = grad * (1 - beta1_t)
|
||||
m_t = state_ops.assign(m, m * beta1_t,
|
||||
use_locking=self._use_locking)
|
||||
with ops.control_dependencies([m_t]):
|
||||
m_t = scatter_add(m, indices, m_scaled_g_values)
|
||||
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
|
||||
v = self.get_slot(var, "v")
|
||||
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
|
||||
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
|
||||
with ops.control_dependencies([v_t]):
|
||||
v_t = scatter_add(v, indices, v_scaled_g_values)
|
||||
v_sqrt = math_ops.sqrt(v_t)
|
||||
var_update = state_ops.assign_sub(var,
|
||||
lr * m_t / (v_sqrt + epsilon_t),
|
||||
use_locking=self._use_locking)
|
||||
return control_flow_ops.group(*[var_update, m_t, v_t])
|
||||
# Tensor versions of the constructor arguments, created in _prepare().
|
||||
self._lr_t = None
|
||||
self._beta1_t = None
|
||||
self._beta2_t = None
|
||||
self._epsilon_t = None
|
||||
|
||||
def _apply_sparse(self, grad, var):
|
||||
return self._apply_sparse_shared(
|
||||
grad.values, var, grad.indices,
|
||||
lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
|
||||
x, i, v, use_locking=self._use_locking))
|
||||
# Created in SparseApply if needed.
|
||||
self._updated_lr = None
|
||||
|
||||
def _resource_scatter_add(self, x, i, v):
|
||||
with ops.control_dependencies(
|
||||
[resource_variable_ops.resource_scatter_add(
|
||||
x.handle, i, v)]):
|
||||
return x.value()
|
||||
def _get_beta_accumulators(self):
|
||||
with ops.init_scope():
|
||||
if context.executing_eagerly():
|
||||
graph = None
|
||||
else:
|
||||
graph = ops.get_default_graph()
|
||||
return (
|
||||
self._get_non_slot_variable("beta1_power", graph=graph),
|
||||
self._get_non_slot_variable("beta2_power", graph=graph),
|
||||
)
|
||||
|
||||
def _resource_apply_sparse(self, grad, var, indices):
|
||||
return self._apply_sparse_shared(
|
||||
grad, var, indices, self._resource_scatter_add)
|
||||
def _create_slots(self, var_list):
|
||||
# Create the beta1 and beta2 accumulators on the same device as the first
|
||||
# variable. Sort the var_list to make sure this device is consistent across
|
||||
# workers (these need to go on the same PS, otherwise some updates are
|
||||
# silently ignored).
|
||||
first_var = min(var_list, key=lambda x: x.name)
|
||||
self._create_non_slot_variable(
|
||||
initial_value=self._beta1, name="beta1_power", colocate_with=first_var
|
||||
)
|
||||
self._create_non_slot_variable(
|
||||
initial_value=self._beta2, name="beta2_power", colocate_with=first_var
|
||||
)
|
||||
|
||||
def _finish(self, update_ops, name_scope):
|
||||
# Update the power accumulators.
|
||||
with ops.control_dependencies(update_ops):
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
with ops.colocate_with(beta1_power):
|
||||
update_beta1 = beta1_power.assign(
|
||||
beta1_power * self._beta1_t, use_locking=self._use_locking)
|
||||
update_beta2 = beta2_power.assign(
|
||||
beta2_power * self._beta2_t, use_locking=self._use_locking)
|
||||
return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
|
||||
name=name_scope)
|
||||
# Create slots for the first and second moments.
|
||||
for v in var_list:
|
||||
self._zeros_slot(v, "m", self._name)
|
||||
self._zeros_slot(v, "v", self._name)
|
||||
|
||||
def _prepare(self):
|
||||
lr = self._call_if_callable(self._lr)
|
||||
beta1 = self._call_if_callable(self._beta1)
|
||||
beta2 = self._call_if_callable(self._beta2)
|
||||
epsilon = self._call_if_callable(self._epsilon)
|
||||
|
||||
self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
|
||||
self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
|
||||
self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
|
||||
self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
|
||||
|
||||
def _apply_dense(self, grad, var):
|
||||
m = self.get_slot(var, "m")
|
||||
v = self.get_slot(var, "v")
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
|
||||
clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
|
||||
grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
|
||||
# Clip gradients by 3 std
|
||||
return training_ops.apply_adam(
|
||||
var,
|
||||
m,
|
||||
v,
|
||||
math_ops.cast(beta1_power, var.dtype.base_dtype),
|
||||
math_ops.cast(beta2_power, var.dtype.base_dtype),
|
||||
math_ops.cast(self._lr_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
|
||||
grad,
|
||||
use_locking=self._use_locking,
|
||||
).op
|
||||
|
||||
def _resource_apply_dense(self, grad, var):
|
||||
m = self.get_slot(var, "m")
|
||||
v = self.get_slot(var, "v")
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
return training_ops.resource_apply_adam(
|
||||
var.handle,
|
||||
m.handle,
|
||||
v.handle,
|
||||
math_ops.cast(beta1_power, grad.dtype.base_dtype),
|
||||
math_ops.cast(beta2_power, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._lr_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
|
||||
grad,
|
||||
use_locking=self._use_locking,
|
||||
)
|
||||
|
||||
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
|
||||
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
|
||||
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
|
||||
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
|
||||
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
|
||||
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
|
||||
lr = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)
|
||||
# m_t = beta1 * m + (1 - beta1) * g_t
|
||||
m = self.get_slot(var, "m")
|
||||
m_scaled_g_values = grad * (1 - beta1_t)
|
||||
m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
|
||||
with ops.control_dependencies([m_t]):
|
||||
m_t = scatter_add(m, indices, m_scaled_g_values)
|
||||
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
|
||||
v = self.get_slot(var, "v")
|
||||
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
|
||||
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
|
||||
with ops.control_dependencies([v_t]):
|
||||
v_t = scatter_add(v, indices, v_scaled_g_values)
|
||||
v_sqrt = math_ops.sqrt(v_t)
|
||||
var_update = state_ops.assign_sub(
|
||||
var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking
|
||||
)
|
||||
return control_flow_ops.group(*[var_update, m_t, v_t])
|
||||
|
||||
def _apply_sparse(self, grad, var):
|
||||
return self._apply_sparse_shared(
|
||||
grad.values,
|
||||
var,
|
||||
grad.indices,
|
||||
lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
|
||||
x, i, v, use_locking=self._use_locking
|
||||
),
|
||||
)
|
||||
|
||||
def _resource_scatter_add(self, x, i, v):
|
||||
with ops.control_dependencies(
|
||||
[resource_variable_ops.resource_scatter_add(x.handle, i, v)]
|
||||
):
|
||||
return x.value()
|
||||
|
||||
def _resource_apply_sparse(self, grad, var, indices):
|
||||
return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add)
|
||||
|
||||
def _finish(self, update_ops, name_scope):
|
||||
# Update the power accumulators.
|
||||
with ops.control_dependencies(update_ops):
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
with ops.colocate_with(beta1_power):
|
||||
update_beta1 = beta1_power.assign(
|
||||
beta1_power * self._beta1_t, use_locking=self._use_locking
|
||||
)
|
||||
update_beta2 = beta2_power.assign(
|
||||
beta2_power * self._beta2_t, use_locking=self._use_locking
|
||||
)
|
||||
return control_flow_ops.group(
|
||||
*update_ops + [update_beta1, update_beta2], name=name_scope
|
||||
)
|
||||
|
|
311
EBMs/data.py
311
EBMs/data.py
|
@ -1,42 +1,48 @@
|
|||
from tensorflow.python.platform import flags
|
||||
from tensorflow.contrib.data.python.ops import batching, threadpool
|
||||
import tensorflow as tf
|
||||
import json
|
||||
from torch.utils.data import Dataset
|
||||
import pickle
|
||||
import os.path as osp
|
||||
import os
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import time
|
||||
from scipy.misc import imread, imresize
|
||||
from skimage.color import rgb2grey
|
||||
from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
|
||||
from torchvision import transforms
|
||||
from imagenet_preprocessing import ImagenetPreprocessor
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import torchvision
|
||||
from imagenet_preprocessing import ImagenetPreprocessor
|
||||
from scipy.misc import imread, imresize
|
||||
from tensorflow.contrib.data.python.ops import batching
|
||||
from tensorflow.python.platform import flags
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN, ImageFolder
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
ROOT_DIR = "./results"
|
||||
|
||||
# Dataset Options
|
||||
flags.DEFINE_string('dsprites_path',
|
||||
'/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
|
||||
'path to dsprites characters')
|
||||
flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'whether cutoff should always in image')
|
||||
flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
|
||||
flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
|
||||
flags.DEFINE_bool('dsize_only', False,'fix all factors except for size of objects')
|
||||
flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
|
||||
flags.DEFINE_bool('dsprites_restrict', False, 'fix all factors except for rotation of objects')
|
||||
flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')
|
||||
flags.DEFINE_string(
|
||||
"dsprites_path",
|
||||
"/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
|
||||
"path to dsprites characters",
|
||||
)
|
||||
flags.DEFINE_string(
|
||||
"imagenet_datadir", "/root/imagenet_big", "whether cutoff should always in image"
|
||||
)
|
||||
flags.DEFINE_bool("dshape_only", False, "fix all factors except for shapes")
|
||||
flags.DEFINE_bool("dpos_only", False, "fix all factors except for positions of shapes")
|
||||
flags.DEFINE_bool("dsize_only", False, "fix all factors except for size of objects")
|
||||
flags.DEFINE_bool("drot_only", False, "fix all factors except for rotation of objects")
|
||||
flags.DEFINE_bool(
|
||||
"dsprites_restrict", False, "fix all factors except for rotation of objects"
|
||||
)
|
||||
flags.DEFINE_string("imagenet_path", "/root/imagenet", "path to imagenet images")
|
||||
|
||||
|
||||
# Data augmentation options
|
||||
flags.DEFINE_bool('cutout_inside', False,'whether cutoff should always in image')
|
||||
flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout')
|
||||
flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout')
|
||||
flags.DEFINE_bool('cutout', False,'whether to add cutout regularizer to data')
|
||||
flags.DEFINE_bool("cutout_inside", False, "whether cutoff should always in image")
|
||||
flags.DEFINE_float("cutout_prob", 1.0, "probability of using cutout")
|
||||
flags.DEFINE_integer("cutout_mask_size", 16, "size of cutout")
|
||||
flags.DEFINE_bool("cutout", False, "whether to add cutout regularizer to data")
|
||||
|
||||
|
||||
def cutout(mask_color=(0, 0, 0)):
|
||||
|
@ -91,13 +97,15 @@ class TFImagenetLoader(Dataset):
|
|||
|
||||
self.curr_sample = 0
|
||||
|
||||
index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
|
||||
index_path = osp.join(FLAGS.imagenet_datadir, "index.json")
|
||||
with open(index_path) as f:
|
||||
metadata = json.load(f)
|
||||
counts = metadata['record_counts']
|
||||
counts = metadata["record_counts"]
|
||||
|
||||
if split == 'train':
|
||||
file_names = list(sorted([x for x in counts.keys() if x.startswith('train')]))
|
||||
if split == "train":
|
||||
file_names = list(
|
||||
sorted([x for x in counts.keys() if x.startswith("train")])
|
||||
)
|
||||
|
||||
result_records_to_skip = None
|
||||
files = []
|
||||
|
@ -111,30 +119,44 @@ class TFImagenetLoader(Dataset):
|
|||
# Record the number to skip in the first file
|
||||
result_records_to_skip = records_to_skip
|
||||
files.append(filename)
|
||||
records_to_read -= (records_in_file - records_to_skip)
|
||||
records_to_read -= records_in_file - records_to_skip
|
||||
records_to_skip = 0
|
||||
else:
|
||||
break
|
||||
else:
|
||||
files = list(sorted([x for x in counts.keys() if x.startswith('validation')]))
|
||||
files = list(
|
||||
sorted([x for x in counts.keys() if x.startswith("validation")])
|
||||
)
|
||||
|
||||
files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
|
||||
preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess
|
||||
preprocess_function = ImagenetPreprocessor(
|
||||
128, dtype=tf.float32, train=False
|
||||
).parse_and_preprocess
|
||||
|
||||
ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
|
||||
ds = tf.data.TFRecordDataset.from_generator(
|
||||
lambda: files, output_types=tf.string
|
||||
)
|
||||
ds = ds.apply(tf.data.TFRecordDataset)
|
||||
ds = ds.take(im_length)
|
||||
ds = ds.prefetch(buffer_size=FLAGS.batch_size)
|
||||
ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
|
||||
ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4))
|
||||
ds = ds.apply(
|
||||
batching.map_and_batch(
|
||||
map_func=preprocess_function,
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_parallel_batches=4,
|
||||
)
|
||||
)
|
||||
ds = ds.prefetch(buffer_size=2)
|
||||
|
||||
ds_iterator = ds.make_initializable_iterator()
|
||||
labels, images = ds_iterator.get_next()
|
||||
self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
|
||||
self.images = tf.clip_by_value(
|
||||
images / 256 + tf.random_uniform(tf.shape(images), 0, 1.0 / 256), 0.0, 1.0
|
||||
)
|
||||
self.labels = labels
|
||||
|
||||
config = tf.ConfigProto(device_count = {'GPU': 0})
|
||||
config = tf.ConfigProto(device_count={"GPU": 0})
|
||||
sess = tf.Session(config=config)
|
||||
sess.run(ds_iterator.initializer)
|
||||
|
||||
|
@ -147,11 +169,17 @@ class TFImagenetLoader(Dataset):
|
|||
|
||||
sess = self.sess
|
||||
|
||||
im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
|
||||
im_corrupt = np.random.uniform(
|
||||
0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3)
|
||||
)
|
||||
label, im = sess.run([self.labels, self.images])
|
||||
im = im * self.rescale
|
||||
label = np.eye(1000)[label.squeeze() - 1]
|
||||
im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label)
|
||||
im, im_corrupt, label = (
|
||||
torch.from_numpy(im),
|
||||
torch.from_numpy(im_corrupt),
|
||||
torch.from_numpy(label),
|
||||
)
|
||||
return im_corrupt, im, label
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -160,6 +188,7 @@ class TFImagenetLoader(Dataset):
|
|||
def __len__(self):
|
||||
return self.im_length
|
||||
|
||||
|
||||
class CelebA(Dataset):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -180,25 +209,18 @@ class CelebA(Dataset):
|
|||
im = imread(path)
|
||||
im = imresize(im, (32, 32))
|
||||
image_size = 32
|
||||
im = im / 255.
|
||||
im = im / 255.0
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
|
||||
elif FLAGS.datasource == 'random':
|
||||
im_corrupt = np.random.uniform(
|
||||
0, 1, size=(image_size, image_size, 3))
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(0, 1, size=(image_size, image_size, 3))
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
||||
|
||||
class Cifar10(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
train=True,
|
||||
full=False,
|
||||
augment=False,
|
||||
noise=True,
|
||||
rescale=1.0):
|
||||
def __init__(self, train=True, full=False, augment=False, noise=True, rescale=1.0):
|
||||
|
||||
if augment:
|
||||
transform_list = [
|
||||
|
@ -215,16 +237,10 @@ class Cifar10(Dataset):
|
|||
transform = transforms.ToTensor()
|
||||
|
||||
self.full = full
|
||||
self.data = CIFAR10(
|
||||
ROOT_DIR,
|
||||
transform=transform,
|
||||
train=train,
|
||||
download=True)
|
||||
self.data = CIFAR10(ROOT_DIR, transform=transform, train=train, download=True)
|
||||
self.test_data = CIFAR10(
|
||||
ROOT_DIR,
|
||||
transform=transform,
|
||||
train=False,
|
||||
download=True)
|
||||
ROOT_DIR, transform=transform, train=False, download=True
|
||||
)
|
||||
self.one_hot_map = np.eye(10)
|
||||
self.noise = noise
|
||||
self.rescale = rescale
|
||||
|
@ -255,16 +271,18 @@ class Cifar10(Dataset):
|
|||
im = im * 255 / 256
|
||||
|
||||
if self.noise:
|
||||
im = im * self.rescale + \
|
||||
np.random.uniform(0, self.rescale * 1 / 256., im.shape)
|
||||
im = im * self.rescale + np.random.uniform(
|
||||
0, self.rescale * 1 / 256.0, im.shape
|
||||
)
|
||||
|
||||
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
|
||||
elif FLAGS.datasource == 'random':
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(
|
||||
0.0, self.rescale, (image_size, image_size, 3))
|
||||
0.0, self.rescale, (image_size, image_size, 3)
|
||||
)
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
||||
|
@ -287,10 +305,8 @@ class Cifar100(Dataset):
|
|||
transform = transforms.ToTensor()
|
||||
|
||||
self.data = CIFAR100(
|
||||
"/root/cifar100",
|
||||
transform=transform,
|
||||
train=train,
|
||||
download=True)
|
||||
"/root/cifar100", transform=transform, train=train, download=True
|
||||
)
|
||||
self.one_hot_map = np.eye(100)
|
||||
|
||||
def __len__(self):
|
||||
|
@ -308,11 +324,10 @@ class Cifar100(Dataset):
|
|||
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
|
||||
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
|
||||
elif FLAGS.datasource == 'random':
|
||||
im_corrupt = np.random.uniform(
|
||||
0.0, 1.0, (image_size, image_size, 3))
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
||||
|
@ -340,11 +355,10 @@ class Svhn(Dataset):
|
|||
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
|
||||
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
|
||||
elif FLAGS.datasource == 'random':
|
||||
im_corrupt = np.random.uniform(
|
||||
0.0, 1.0, (image_size, image_size, 3))
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
||||
|
@ -352,9 +366,8 @@ class Svhn(Dataset):
|
|||
class Mnist(Dataset):
|
||||
def __init__(self, train=True, rescale=1.0):
|
||||
self.data = MNIST(
|
||||
"/root/mnist",
|
||||
transform=transforms.ToTensor(),
|
||||
download=True, train=train)
|
||||
"/root/mnist", transform=transforms.ToTensor(), download=True, train=train
|
||||
)
|
||||
self.labels = np.eye(10)
|
||||
self.rescale = rescale
|
||||
|
||||
|
@ -367,13 +380,13 @@ class Mnist(Dataset):
|
|||
im = im.squeeze()
|
||||
# im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28))
|
||||
# im = im.numpy() / 2 + 0.2
|
||||
im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28))
|
||||
im = im.numpy() / 256 * 255 + np.random.uniform(0, 1.0 / 256, (28, 28))
|
||||
im = im * self.rescale
|
||||
image_size = 28
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
|
||||
elif FLAGS.datasource == 'random':
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(0, self.rescale, (28, 28))
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
@ -381,54 +394,63 @@ class Mnist(Dataset):
|
|||
|
||||
class DSprites(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
cond_size=False,
|
||||
cond_shape=False,
|
||||
cond_pos=False,
|
||||
cond_rot=False):
|
||||
self, cond_size=False, cond_shape=False, cond_pos=False, cond_rot=False
|
||||
):
|
||||
dat = np.load(FLAGS.dsprites_path)
|
||||
|
||||
if FLAGS.dshape_only:
|
||||
l = dat['latents_values']
|
||||
mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 /
|
||||
31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
|
||||
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (10000, 1))
|
||||
l = dat["latents_values"]
|
||||
mask = (
|
||||
(l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 2] == 0.5)
|
||||
& (l[:, 3] == 30 * np.pi / 39)
|
||||
)
|
||||
self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (10000, 1))
|
||||
self.label = self.label[:, 1:2]
|
||||
elif FLAGS.dpos_only:
|
||||
l = dat['latents_values']
|
||||
l = dat["latents_values"]
|
||||
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
|
||||
mask = (l[:, 1] == 1) & (
|
||||
l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
|
||||
self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (100, 1))
|
||||
mask = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
|
||||
self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (100, 1))
|
||||
self.label = self.label[:, 4:] + 0.5
|
||||
elif FLAGS.dsize_only:
|
||||
l = dat['latents_values']
|
||||
l = dat["latents_values"]
|
||||
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
|
||||
mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 /
|
||||
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
|
||||
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (10000, 1))
|
||||
self.label = (self.label[:, 2:3])
|
||||
mask = (
|
||||
(l[:, 3] == 30 * np.pi / 39)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 1] == 1)
|
||||
)
|
||||
self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (10000, 1))
|
||||
self.label = self.label[:, 2:3]
|
||||
elif FLAGS.drot_only:
|
||||
l = dat['latents_values']
|
||||
mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 /
|
||||
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
|
||||
self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (100, 1))
|
||||
self.label = (self.label[:, 3:4])
|
||||
l = dat["latents_values"]
|
||||
mask = (
|
||||
(l[:, 2] == 0.5)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 1] == 1)
|
||||
)
|
||||
self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (100, 1))
|
||||
self.label = self.label[:, 3:4]
|
||||
self.label = np.concatenate(
|
||||
[np.cos(self.label), np.sin(self.label)], axis=1)
|
||||
[np.cos(self.label), np.sin(self.label)], axis=1
|
||||
)
|
||||
elif FLAGS.dsprites_restrict:
|
||||
l = dat['latents_values']
|
||||
l = dat["latents_values"]
|
||||
mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39)
|
||||
|
||||
self.data = dat['imgs'][mask]
|
||||
self.label = dat['latents_values'][mask]
|
||||
self.data = dat["imgs"][mask]
|
||||
self.label = dat["latents_values"][mask]
|
||||
else:
|
||||
self.data = dat['imgs']
|
||||
self.label = dat['latents_values']
|
||||
self.data = dat["imgs"]
|
||||
self.label = dat["latents_values"]
|
||||
|
||||
if cond_size:
|
||||
self.label = self.label[:, 2:3]
|
||||
|
@ -439,7 +461,8 @@ class DSprites(Dataset):
|
|||
elif cond_rot:
|
||||
self.label = self.label[:, 3:4]
|
||||
self.label = np.concatenate(
|
||||
[np.cos(self.label), np.sin(self.label)], axis=1)
|
||||
[np.cos(self.label), np.sin(self.label)], axis=1
|
||||
)
|
||||
else:
|
||||
self.label = self.label[:, 1:2]
|
||||
|
||||
|
@ -452,20 +475,20 @@ class DSprites(Dataset):
|
|||
im = self.data[index]
|
||||
image_size = 64
|
||||
|
||||
if not (
|
||||
FLAGS.dpos_only or FLAGS.dsize_only) and (
|
||||
not FLAGS.cond_size) and (
|
||||
not FLAGS.cond_pos) and (
|
||||
not FLAGS.cond_rot) and (
|
||||
not FLAGS.drot_only):
|
||||
label = self.identity[self.label[index].astype(
|
||||
np.int32) - 1].squeeze()
|
||||
if (
|
||||
not (FLAGS.dpos_only or FLAGS.dsize_only)
|
||||
and (not FLAGS.cond_size)
|
||||
and (not FLAGS.cond_pos)
|
||||
and (not FLAGS.cond_rot)
|
||||
and (not FLAGS.drot_only)
|
||||
):
|
||||
label = self.identity[self.label[index].astype(np.int32) - 1].squeeze()
|
||||
else:
|
||||
label = self.label[index]
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
|
||||
elif FLAGS.datasource == 'random':
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
@ -478,25 +501,20 @@ class Imagenet(Dataset):
|
|||
for i in range(1, 11):
|
||||
f = pickle.load(
|
||||
open(
|
||||
osp.join(
|
||||
FLAGS.imagenet_path,
|
||||
'train_data_batch_{}'.format(i)),
|
||||
'rb'))
|
||||
osp.join(FLAGS.imagenet_path, "train_data_batch_{}".format(i)),
|
||||
"rb",
|
||||
)
|
||||
)
|
||||
if i == 1:
|
||||
labels = f['labels']
|
||||
data = f['data']
|
||||
labels = f["labels"]
|
||||
data = f["data"]
|
||||
else:
|
||||
labels.extend(f['labels'])
|
||||
data = np.vstack((data, f['data']))
|
||||
labels.extend(f["labels"])
|
||||
data = np.vstack((data, f["data"]))
|
||||
else:
|
||||
f = pickle.load(
|
||||
open(
|
||||
osp.join(
|
||||
FLAGS.imagenet_path,
|
||||
'val_data'),
|
||||
'rb'))
|
||||
labels = f['labels']
|
||||
data = f['data']
|
||||
f = pickle.load(open(osp.join(FLAGS.imagenet_path, "val_data"), "rb"))
|
||||
labels = f["labels"]
|
||||
data = f["data"]
|
||||
|
||||
self.labels = labels
|
||||
self.data = data
|
||||
|
@ -520,11 +538,10 @@ class Imagenet(Dataset):
|
|||
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
|
||||
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
|
||||
elif FLAGS.datasource == 'random':
|
||||
im_corrupt = np.random.uniform(
|
||||
0.0, 1.0, (image_size, image_size, 3))
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
||||
|
|
|
@ -1,68 +1,100 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from hmc import hmc
|
||||
from custom_adam import AdamOptimizer
|
||||
from models import DspritesNet
|
||||
from scipy.misc import imsave
|
||||
from tensorflow.python.platform import flags
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from models import DspritesNet
|
||||
from utils import optimistic_restore, ReplayBuffer
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
from rl_algs.logger import TensorBoardOutputFormat
|
||||
from scipy.misc import imsave
|
||||
import os
|
||||
from custom_adam import AdamOptimizer
|
||||
from tqdm import tqdm
|
||||
from utils import ReplayBuffer
|
||||
|
||||
flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
|
||||
flags.DEFINE_integer('data_workers', 4, 'Number of workers to do things')
|
||||
flags.DEFINE_string('logdir', 'cachedir', 'directory for logging')
|
||||
flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
|
||||
flags.DEFINE_float('step_lr', 500, 'size of gradient descent size')
|
||||
flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters')
|
||||
flags.DEFINE_bool('cclass', True, 'not cclass')
|
||||
flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons')
|
||||
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
|
||||
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
|
||||
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
|
||||
flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results')
|
||||
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
|
||||
flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.')
|
||||
flags.DEFINE_bool('joint_shape', False, 'whether to use pos_size or pos_shape')
|
||||
flags.DEFINE_bool('joint_rot', False, 'whether to use pos_size or pos_shape')
|
||||
flags.DEFINE_integer("batch_size", 256, "Size of inputs")
|
||||
flags.DEFINE_integer("data_workers", 4, "Number of workers to do things")
|
||||
flags.DEFINE_string("logdir", "cachedir", "directory for logging")
|
||||
flags.DEFINE_string(
|
||||
"savedir", "cachedir", "location where log of experiments will be stored"
|
||||
)
|
||||
flags.DEFINE_integer(
|
||||
"num_filters",
|
||||
64,
|
||||
"number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
|
||||
)
|
||||
flags.DEFINE_float("step_lr", 500, "size of gradient descent size")
|
||||
flags.DEFINE_string(
|
||||
"dsprites_path",
|
||||
"/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
|
||||
"path to dsprites characters",
|
||||
)
|
||||
flags.DEFINE_bool("cclass", True, "not cclass")
|
||||
flags.DEFINE_bool("proj_cclass", False, "use for backwards compatibility reasons")
|
||||
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
|
||||
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
|
||||
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
|
||||
flags.DEFINE_bool("plot_curve", False, "Generate a curve of results")
|
||||
flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
|
||||
flags.DEFINE_string(
|
||||
"task",
|
||||
"conceptcombine",
|
||||
"conceptcombine, labeldiscover, gentest, genbaseline, etc.",
|
||||
)
|
||||
flags.DEFINE_bool("joint_shape", False, "whether to use pos_size or pos_shape")
|
||||
flags.DEFINE_bool("joint_rot", False, "whether to use pos_size or pos_shape")
|
||||
|
||||
# Conditions on which models to use
|
||||
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
|
||||
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
|
||||
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
|
||||
flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale')
|
||||
flags.DEFINE_bool("cond_pos", True, "whether to condition on position")
|
||||
flags.DEFINE_bool("cond_rot", True, "whether to condition on rotation")
|
||||
flags.DEFINE_bool("cond_shape", True, "whether to condition on shape")
|
||||
flags.DEFINE_bool("cond_scale", True, "whether to condition on scale")
|
||||
|
||||
flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of experiments')
|
||||
flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of experiments')
|
||||
flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of experiments')
|
||||
flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of experiments')
|
||||
flags.DEFINE_integer('resume_size', 169000, 'First iteration to resume')
|
||||
flags.DEFINE_integer('resume_shape', 477000, 'Second iteration to resume')
|
||||
flags.DEFINE_integer('resume_pos', 8000, 'Second iteration to resume')
|
||||
flags.DEFINE_integer('resume_rot', 690000, 'Second iteration to resume')
|
||||
flags.DEFINE_integer('break_steps', 300, 'steps to break')
|
||||
flags.DEFINE_string("exp_size", "dsprites_2018_cond_size", "name of experiments")
|
||||
flags.DEFINE_string("exp_shape", "dsprites_2018_cond_shape", "name of experiments")
|
||||
flags.DEFINE_string("exp_pos", "dsprites_2018_cond_pos_cert", "name of experiments")
|
||||
flags.DEFINE_string("exp_rot", "dsprites_cond_rot_119_00", "name of experiments")
|
||||
flags.DEFINE_integer("resume_size", 169000, "First iteration to resume")
|
||||
flags.DEFINE_integer("resume_shape", 477000, "Second iteration to resume")
|
||||
flags.DEFINE_integer("resume_pos", 8000, "Second iteration to resume")
|
||||
flags.DEFINE_integer("resume_rot", 690000, "Second iteration to resume")
|
||||
flags.DEFINE_integer("break_steps", 300, "steps to break")
|
||||
|
||||
# Whether to train for gentest
|
||||
flags.DEFINE_bool('train', False, 'whether to train on generalization into multiple different predictions')
|
||||
flags.DEFINE_bool(
|
||||
"train",
|
||||
False,
|
||||
"whether to train on generalization into multiple different predictions",
|
||||
)
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class DSpritesGen(Dataset):
|
||||
def __init__(self, data, latents, frac=0.0):
|
||||
|
||||
l = latents
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
|
||||
mask_size = (
|
||||
(l[:, 3] == 30 * np.pi / 39)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 2] == 0.5)
|
||||
)
|
||||
elif FLAGS.joint_rot:
|
||||
mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
|
||||
mask_size = (
|
||||
(l[:, 1] == 1)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 2] == 0.5)
|
||||
)
|
||||
else:
|
||||
mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1)
|
||||
mask_size = (
|
||||
(l[:, 3] == 30 * np.pi / 39)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 1] == 1)
|
||||
)
|
||||
|
||||
mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
|
||||
|
||||
|
@ -80,12 +112,14 @@ class DSpritesGen(Dataset):
|
|||
self.data = np.concatenate((data_pos, data_size), axis=0)
|
||||
self.label = np.concatenate((l_pos, l_size), axis=0)
|
||||
|
||||
mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39))
|
||||
mask_neg = (~(mask_size & mask_pos)) & (
|
||||
(l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39)
|
||||
)
|
||||
data_add = data[mask_neg]
|
||||
l_add = l[mask_neg]
|
||||
|
||||
perm_idx = np.random.permutation(data_add.shape[0])
|
||||
select_idx = perm_idx[:int(frac*perm_idx.shape[0])]
|
||||
select_idx = perm_idx[: int(frac * perm_idx.shape[0])]
|
||||
data_add = data_add[select_idx]
|
||||
l_add = l_add[select_idx]
|
||||
|
||||
|
@ -104,7 +138,9 @@ class DSpritesGen(Dataset):
|
|||
if FLAGS.joint_shape:
|
||||
label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
|
||||
elif FLAGS.joint_rot:
|
||||
label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])])
|
||||
label_size = np.array(
|
||||
[np.cos(self.label[index, 3]), np.sin(self.label[index, 3])]
|
||||
)
|
||||
else:
|
||||
label_size = self.label[index, 2:3]
|
||||
|
||||
|
@ -114,14 +150,16 @@ class DSpritesGen(Dataset):
|
|||
|
||||
|
||||
def labeldiscover(sess, kvs, data, latents, save_exp_dir):
|
||||
LABEL_SIZE = kvs['LABEL_SIZE']
|
||||
model_size = kvs['model_size']
|
||||
weight_size = kvs['weight_size']
|
||||
x_mod = kvs['X_NOISE']
|
||||
LABEL_SIZE = kvs["LABEL_SIZE"]
|
||||
model_size = kvs["model_size"]
|
||||
weight_size = kvs["weight_size"]
|
||||
x_mod = kvs["X_NOISE"]
|
||||
|
||||
label_output = LABEL_SIZE
|
||||
for i in range(FLAGS.num_steps):
|
||||
label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03)
|
||||
label_output = label_output + tf.random_normal(
|
||||
tf.shape(label_output), mean=0.0, stddev=0.03
|
||||
)
|
||||
e_noise = model_size.forward(x_mod, weight_size, label=label_output)
|
||||
label_grad = tf.gradients(e_noise, [label_output])[0]
|
||||
# label_grad = tf.Print(label_grad, [label_grad])
|
||||
|
@ -130,13 +168,13 @@ def labeldiscover(sess, kvs, data, latents, save_exp_dir):
|
|||
|
||||
diffs = []
|
||||
for i in range(30):
|
||||
s = i*FLAGS.batch_size
|
||||
d = (i+1)*FLAGS.batch_size
|
||||
s = i * FLAGS.batch_size
|
||||
d = (i + 1) * FLAGS.batch_size
|
||||
data_i = data[s:d]
|
||||
latent_i = latents[s:d]
|
||||
latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))
|
||||
|
||||
feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init}
|
||||
feed_dict = {x_mod: data_i, LABEL_SIZE: latent_init}
|
||||
size_pred = sess.run([label_output], feed_dict)[0]
|
||||
size_gt = latent_i[:, 2:3]
|
||||
|
||||
|
@ -155,9 +193,11 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
|
||||
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
|
||||
|
||||
weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))
|
||||
weights_baseline = model_baseline.construct_weights(
|
||||
"context_baseline_{}".format(frac)
|
||||
)
|
||||
|
||||
X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32)
|
||||
X_feed = tf.placeholder(shape=(None, 2 * FLAGS.num_filters), dtype=tf.float32)
|
||||
X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
|
||||
|
||||
X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
|
||||
|
@ -168,14 +208,20 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
gvs = [(k, v) for (k, v) in gvs if k is not None]
|
||||
train_op = optimizer.apply_gradients(gvs)
|
||||
|
||||
dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
|
||||
dataloader = DataLoader(
|
||||
DSpritesGen(data, latents, frac=frac),
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_workers=6,
|
||||
drop_last=True,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
datafull = data
|
||||
|
||||
itr = 0
|
||||
saver = tf.train.Saver()
|
||||
|
||||
vs = optimizer.variables()
|
||||
optimizer.variables()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
if FLAGS.train:
|
||||
|
@ -185,7 +231,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
data_corrupt = data_corrupt.numpy()
|
||||
label_size, label_pos = label_size.numpy(), label_pos.numpy()
|
||||
|
||||
data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters)
|
||||
data_corrupt = np.random.randn(
|
||||
data_corrupt.shape[0], 2 * FLAGS.num_filters
|
||||
)
|
||||
label_comb = np.concatenate([label_size, label_pos], axis=1)
|
||||
|
||||
feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}
|
||||
|
@ -196,23 +244,27 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
|
||||
itr += 1
|
||||
|
||||
saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))
|
||||
saver.save(sess, osp.join(save_exp_dir, "model_genbaseline"))
|
||||
|
||||
saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))
|
||||
saver.restore(sess, osp.join(save_exp_dir, "model_genbaseline"))
|
||||
|
||||
l = latents
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
|
||||
else:
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
|
||||
~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
|
||||
)
|
||||
|
||||
data_gen = datafull[mask_gen]
|
||||
latents_gen = latents[mask_gen]
|
||||
losses = []
|
||||
|
||||
for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
|
||||
data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters)
|
||||
for dat, latent in zip(
|
||||
np.array_split(data_gen, 10), np.array_split(latents_gen, 10)
|
||||
):
|
||||
data_init = np.random.randn(dat.shape[0], 2 * FLAGS.num_filters)
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
|
||||
|
@ -220,16 +272,19 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
latent = np.concatenate([latent_size, latent_pos], axis=1)
|
||||
feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
|
||||
else:
|
||||
feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat}
|
||||
feed_dict = {X_feed: data_init, LABEL: latent[:, [2, 4, 5]], X_label: dat}
|
||||
loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
|
||||
# print(loss)
|
||||
losses.append(loss)
|
||||
|
||||
print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))
|
||||
|
||||
print(
|
||||
"Overall MSE for generalization of {} for fraction of {}".format(
|
||||
np.mean(losses), frac
|
||||
)
|
||||
)
|
||||
|
||||
data_try = data_gen[:10]
|
||||
data_init = np.random.randn(10, 2*FLAGS.num_filters)
|
||||
data_init = np.random.randn(10, 2 * FLAGS.num_filters)
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1]
|
||||
|
@ -252,7 +307,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
x_output_wrap[:, 1:-1, 1:-1] = x_output
|
||||
data_try_wrap[:, 1:-1, 1:-1] = data_try
|
||||
|
||||
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
|
||||
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
|
||||
-1, 66 * 2
|
||||
)
|
||||
impath = osp.join(save_exp_dir, im_name)
|
||||
imsave(impath, im_output)
|
||||
print("Successfully saved images at {}".format(impath))
|
||||
|
@ -261,38 +318,40 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
|
||||
|
||||
def gentest(sess, kvs, data, latents, save_exp_dir):
|
||||
X_NOISE = kvs['X_NOISE']
|
||||
LABEL_SIZE = kvs['LABEL_SIZE']
|
||||
LABEL_SHAPE = kvs['LABEL_SHAPE']
|
||||
LABEL_POS = kvs['LABEL_POS']
|
||||
LABEL_ROT = kvs['LABEL_ROT']
|
||||
model_size = kvs['model_size']
|
||||
model_shape = kvs['model_shape']
|
||||
model_pos = kvs['model_pos']
|
||||
model_rot = kvs['model_rot']
|
||||
weight_size = kvs['weight_size']
|
||||
weight_shape = kvs['weight_shape']
|
||||
weight_pos = kvs['weight_pos']
|
||||
weight_rot = kvs['weight_rot']
|
||||
X_NOISE = kvs["X_NOISE"]
|
||||
LABEL_SIZE = kvs["LABEL_SIZE"]
|
||||
LABEL_SHAPE = kvs["LABEL_SHAPE"]
|
||||
LABEL_POS = kvs["LABEL_POS"]
|
||||
LABEL_ROT = kvs["LABEL_ROT"]
|
||||
model_size = kvs["model_size"]
|
||||
model_shape = kvs["model_shape"]
|
||||
model_pos = kvs["model_pos"]
|
||||
model_rot = kvs["model_rot"]
|
||||
weight_size = kvs["weight_size"]
|
||||
weight_shape = kvs["weight_shape"]
|
||||
weight_pos = kvs["weight_pos"]
|
||||
weight_rot = kvs["weight_rot"]
|
||||
X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
|
||||
|
||||
datafull = data
|
||||
# Test combination of generalization where we use slices of both training
|
||||
x_final = X_NOISE
|
||||
x_mod_size = X_NOISE
|
||||
x_mod_pos = X_NOISE
|
||||
|
||||
for i in range(FLAGS.num_steps):
|
||||
|
||||
# use cond_pos
|
||||
|
||||
energies = []
|
||||
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
|
||||
x_mod_pos = x_mod_pos + tf.random_normal(
|
||||
tf.shape(x_mod_pos), mean=0.0, stddev=0.005
|
||||
)
|
||||
e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)
|
||||
|
||||
# energies.append(e_noise)
|
||||
x_grad = tf.gradients(e_noise, [x_final])[0]
|
||||
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
|
||||
x_mod_pos = x_mod_pos + tf.random_normal(
|
||||
tf.shape(x_mod_pos), mean=0.0, stddev=0.005
|
||||
)
|
||||
x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
|
||||
x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)
|
||||
|
||||
|
@ -332,42 +391,70 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
x_mod = x_mod_pos
|
||||
x_final = x_mod
|
||||
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
|
||||
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
|
||||
loss_kl = model_shape.forward(
|
||||
x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True
|
||||
) + model_pos.forward(
|
||||
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
|
||||
)
|
||||
|
||||
energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
|
||||
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_pos = model_shape.forward(
|
||||
X, weight_shape, reuse=True, label=LABEL_SHAPE
|
||||
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
|
||||
energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
|
||||
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_neg = model_shape.forward(
|
||||
tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE
|
||||
) + model_pos.forward(
|
||||
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
|
||||
)
|
||||
elif FLAGS.joint_rot:
|
||||
loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
|
||||
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
|
||||
loss_kl = model_rot.forward(
|
||||
x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True
|
||||
) + model_pos.forward(
|
||||
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
|
||||
)
|
||||
|
||||
energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
|
||||
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_pos = model_rot.forward(
|
||||
X, weight_rot, reuse=True, label=LABEL_ROT
|
||||
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
|
||||
energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
|
||||
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_neg = model_rot.forward(
|
||||
tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT
|
||||
) + model_pos.forward(
|
||||
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
|
||||
)
|
||||
else:
|
||||
loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
|
||||
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
|
||||
loss_kl = model_size.forward(
|
||||
x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True
|
||||
) + model_pos.forward(
|
||||
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
|
||||
)
|
||||
|
||||
energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
|
||||
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_pos = model_size.forward(
|
||||
X, weight_size, reuse=True, label=LABEL_SIZE
|
||||
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
|
||||
energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
|
||||
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_neg = model_size.forward(
|
||||
tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE
|
||||
) + model_pos.forward(
|
||||
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
|
||||
)
|
||||
|
||||
energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
|
||||
energy_neg_reduced = energy_neg - tf.reduce_min(energy_neg)
|
||||
coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
|
||||
norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
|
||||
neg_loss = coeff * (-1*energy_neg) / norm_constant
|
||||
coeff * (-1 * energy_neg) / norm_constant
|
||||
|
||||
loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
|
||||
loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))
|
||||
loss_total = (
|
||||
loss_ml
|
||||
+ tf.reduce_mean(loss_kl)
|
||||
+ 1
|
||||
* (
|
||||
tf.reduce_mean(tf.square(energy_pos))
|
||||
+ tf.reduce_mean(tf.square(energy_neg))
|
||||
)
|
||||
)
|
||||
|
||||
optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
|
||||
gvs = optimizer.compute_gradients(loss_total)
|
||||
|
@ -377,7 +464,13 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
vs = optimizer.variables()
|
||||
sess.run(tf.variables_initializer(vs))
|
||||
|
||||
dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
|
||||
dataloader = DataLoader(
|
||||
DSpritesGen(data, latents),
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_workers=6,
|
||||
drop_last=True,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
x_off = tf.reduce_mean(tf.square(x_mod - X))
|
||||
|
||||
|
@ -385,12 +478,10 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
saver = tf.train.Saver()
|
||||
x_mod = None
|
||||
|
||||
|
||||
if FLAGS.train:
|
||||
replay_buffer = ReplayBuffer(10000)
|
||||
for _ in range(1):
|
||||
|
||||
|
||||
for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
|
||||
data_corrupt = data_corrupt.numpy()[:, :, :]
|
||||
data = data.numpy()[:, :, :]
|
||||
|
@ -398,29 +489,50 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
if x_mod is not None:
|
||||
replay_buffer.add(x_mod)
|
||||
replay_batch = replay_buffer.sample(FLAGS.batch_size)
|
||||
replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
|
||||
replay_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95
|
||||
data_corrupt[replay_mask] = replay_batch[replay_mask]
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
|
||||
feed_dict = {
|
||||
X_NOISE: data_corrupt,
|
||||
X: data,
|
||||
LABEL_SHAPE: label_size,
|
||||
LABEL_POS: label_pos,
|
||||
}
|
||||
elif FLAGS.joint_rot:
|
||||
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
|
||||
feed_dict = {
|
||||
X_NOISE: data_corrupt,
|
||||
X: data,
|
||||
LABEL_ROT: label_size,
|
||||
LABEL_POS: label_pos,
|
||||
}
|
||||
else:
|
||||
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}
|
||||
feed_dict = {
|
||||
X_NOISE: data_corrupt,
|
||||
X: data,
|
||||
LABEL_SIZE: label_size,
|
||||
LABEL_POS: label_pos,
|
||||
}
|
||||
|
||||
_, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
|
||||
_, off_value, e_pos, e_neg, x_mod = sess.run(
|
||||
[train_op, x_off, energy_pos, energy_neg, x_final],
|
||||
feed_dict=feed_dict,
|
||||
)
|
||||
itr += 1
|
||||
|
||||
if itr % 10 == 0:
|
||||
print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr))
|
||||
print(
|
||||
"x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(
|
||||
off_value, e_pos.mean(), e_neg.mean(), itr
|
||||
)
|
||||
)
|
||||
|
||||
if itr == FLAGS.break_steps:
|
||||
break
|
||||
|
||||
saver.save(sess, osp.join(save_exp_dir, "model_gentest"))
|
||||
|
||||
saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))
|
||||
|
||||
saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))
|
||||
saver.restore(sess, osp.join(save_exp_dir, "model_gentest"))
|
||||
|
||||
l = latents
|
||||
|
||||
|
@ -429,22 +541,43 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
elif FLAGS.joint_rot:
|
||||
mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
|
||||
else:
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
|
||||
~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
|
||||
)
|
||||
|
||||
data_gen = datafull[mask_gen]
|
||||
latents_gen = latents[mask_gen]
|
||||
|
||||
losses = []
|
||||
|
||||
for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
|
||||
for dat, latent in zip(
|
||||
np.array_split(data_gen, 120), np.array_split(latents_gen, 120)
|
||||
):
|
||||
x = 0.5 + np.random.randn(*dat.shape)
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
|
||||
feed_dict = {
|
||||
LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1],
|
||||
LABEL_POS: latent[:, 4:],
|
||||
X_NOISE: x,
|
||||
X: dat,
|
||||
}
|
||||
elif FLAGS.joint_rot:
|
||||
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
|
||||
feed_dict = {
|
||||
LABEL_ROT: np.concatenate(
|
||||
[np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1
|
||||
),
|
||||
LABEL_POS: latent[:, 4:],
|
||||
X_NOISE: x,
|
||||
X: dat,
|
||||
}
|
||||
else:
|
||||
feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
|
||||
feed_dict = {
|
||||
LABEL_SIZE: latent[:, 2:3],
|
||||
LABEL_POS: latent[:, 4:],
|
||||
X_NOISE: x,
|
||||
X: dat,
|
||||
}
|
||||
|
||||
for i in range(2):
|
||||
x = sess.run([x_final], feed_dict=feed_dict)[0]
|
||||
|
@ -461,11 +594,25 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
latent_pos = latents_gen[:10, 4:]
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos}
|
||||
feed_dict = {
|
||||
X_NOISE: data_init,
|
||||
LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32) - 1],
|
||||
LABEL_POS: latent_pos,
|
||||
}
|
||||
elif FLAGS.joint_rot:
|
||||
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1), LABEL_POS: latent[:10, 4:], X_NOISE: data_init}
|
||||
feed_dict = {
|
||||
LABEL_ROT: np.concatenate(
|
||||
[np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1
|
||||
),
|
||||
LABEL_POS: latent[:10, 4:],
|
||||
X_NOISE: data_init,
|
||||
}
|
||||
else:
|
||||
feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}
|
||||
feed_dict = {
|
||||
X_NOISE: data_init,
|
||||
LABEL_SIZE: latent_scale,
|
||||
LABEL_POS: latent_pos,
|
||||
}
|
||||
|
||||
x_output = sess.run([x_final], feed_dict=feed_dict)[0]
|
||||
|
||||
|
@ -480,27 +627,28 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
x_output_wrap[:, 1:-1, 1:-1] = x_output
|
||||
data_try_wrap[:, 1:-1, 1:-1] = data_try
|
||||
|
||||
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
|
||||
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
|
||||
-1, 66 * 2
|
||||
)
|
||||
impath = osp.join(save_exp_dir, im_name)
|
||||
imsave(impath, im_output)
|
||||
print("Successfully saved images at {}".format(impath))
|
||||
|
||||
|
||||
|
||||
def conceptcombine(sess, kvs, data, latents, save_exp_dir):
|
||||
X_NOISE = kvs['X_NOISE']
|
||||
LABEL_SIZE = kvs['LABEL_SIZE']
|
||||
LABEL_SHAPE = kvs['LABEL_SHAPE']
|
||||
LABEL_POS = kvs['LABEL_POS']
|
||||
LABEL_ROT = kvs['LABEL_ROT']
|
||||
model_size = kvs['model_size']
|
||||
model_shape = kvs['model_shape']
|
||||
model_pos = kvs['model_pos']
|
||||
model_rot = kvs['model_rot']
|
||||
weight_size = kvs['weight_size']
|
||||
weight_shape = kvs['weight_shape']
|
||||
weight_pos = kvs['weight_pos']
|
||||
weight_rot = kvs['weight_rot']
|
||||
X_NOISE = kvs["X_NOISE"]
|
||||
LABEL_SIZE = kvs["LABEL_SIZE"]
|
||||
LABEL_SHAPE = kvs["LABEL_SHAPE"]
|
||||
LABEL_POS = kvs["LABEL_POS"]
|
||||
LABEL_ROT = kvs["LABEL_ROT"]
|
||||
model_size = kvs["model_size"]
|
||||
model_shape = kvs["model_shape"]
|
||||
model_pos = kvs["model_pos"]
|
||||
model_rot = kvs["model_rot"]
|
||||
weight_size = kvs["weight_size"]
|
||||
weight_shape = kvs["weight_shape"]
|
||||
weight_pos = kvs["weight_pos"]
|
||||
weight_rot = kvs["weight_rot"]
|
||||
|
||||
x_mod = X_NOISE
|
||||
for i in range(FLAGS.num_steps):
|
||||
|
@ -540,13 +688,18 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
|
|||
data_try = data[:10]
|
||||
data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
|
||||
label_scale = latents[:10, 2:3]
|
||||
label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)]
|
||||
label_shape = np.eye(3)[(latents[:10, 1] - 1).astype(np.uint8)]
|
||||
label_rot = latents[:10, 3:4]
|
||||
label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
|
||||
label_pos = latents[:10, 4:]
|
||||
|
||||
feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos,
|
||||
LABEL_ROT: label_rot}
|
||||
feed_dict = {
|
||||
X_NOISE: data_init,
|
||||
LABEL_SIZE: label_scale,
|
||||
LABEL_SHAPE: label_shape,
|
||||
LABEL_POS: label_pos,
|
||||
LABEL_ROT: label_rot,
|
||||
}
|
||||
x_out = sess.run([x_final], feed_dict)[0]
|
||||
|
||||
im_name = "im"
|
||||
|
@ -569,14 +722,15 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
|
|||
x_out_pad[:, 1:-1, 1:-1] = x_out
|
||||
data_try_pad[:, 1:-1, 1:-1] = data_try
|
||||
|
||||
im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2)
|
||||
im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66 * 2)
|
||||
impath = osp.join(save_exp_dir, im_name)
|
||||
imsave(impath, im_output)
|
||||
print("Successfully saved images at {}".format(impath))
|
||||
|
||||
|
||||
def main():
|
||||
data = np.load(FLAGS.dsprites_path)['imgs']
|
||||
l = latents = np.load(FLAGS.dsprites_path)['latents_values']
|
||||
data = np.load(FLAGS.dsprites_path)["imgs"]
|
||||
l = latents = np.load(FLAGS.dsprites_path)["latents_values"]
|
||||
|
||||
np.random.seed(1)
|
||||
idx = np.random.permutation(data.shape[0])
|
||||
|
@ -589,52 +743,74 @@ def main():
|
|||
|
||||
# Model 1 will be conditioned on size
|
||||
model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
|
||||
weight_size = model_size.construct_weights('context_0')
|
||||
weight_size = model_size.construct_weights("context_0")
|
||||
|
||||
# Model 2 will be conditioned on shape
|
||||
model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
|
||||
weight_shape = model_shape.construct_weights('context_1')
|
||||
weight_shape = model_shape.construct_weights("context_1")
|
||||
|
||||
# Model 3 will be conditioned on position
|
||||
model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
|
||||
weight_pos = model_pos.construct_weights('context_2')
|
||||
weight_pos = model_pos.construct_weights("context_2")
|
||||
|
||||
# Model 4 will be conditioned on rotation
|
||||
model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
|
||||
weight_rot = model_rot.construct_weights('context_3')
|
||||
weight_rot = model_rot.construct_weights("context_3")
|
||||
|
||||
sess.run(tf.global_variables_initializer())
|
||||
save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size))
|
||||
save_path_size = osp.join(
|
||||
FLAGS.logdir, FLAGS.exp_size, "model_{}".format(FLAGS.resume_size)
|
||||
)
|
||||
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0))
|
||||
v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list}
|
||||
v_list = tf.get_collection(
|
||||
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(0)
|
||||
)
|
||||
v_map = {
|
||||
(v.name.replace("context_{}".format(0), "context_0")[:-2]): v for v in v_list
|
||||
}
|
||||
|
||||
if FLAGS.cond_scale:
|
||||
saver = tf.train.Saver(v_map)
|
||||
saver.restore(sess, save_path_size)
|
||||
|
||||
save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape))
|
||||
save_path_shape = osp.join(
|
||||
FLAGS.logdir, FLAGS.exp_shape, "model_{}".format(FLAGS.resume_shape)
|
||||
)
|
||||
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
|
||||
v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
|
||||
v_list = tf.get_collection(
|
||||
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(1)
|
||||
)
|
||||
v_map = {
|
||||
(v.name.replace("context_{}".format(1), "context_0")[:-2]): v for v in v_list
|
||||
}
|
||||
|
||||
if FLAGS.cond_shape:
|
||||
saver = tf.train.Saver(v_map)
|
||||
saver.restore(sess, save_path_shape)
|
||||
|
||||
|
||||
save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos))
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
|
||||
v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
|
||||
save_path_pos = osp.join(
|
||||
FLAGS.logdir, FLAGS.exp_pos, "model_{}".format(FLAGS.resume_pos)
|
||||
)
|
||||
v_list = tf.get_collection(
|
||||
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(2)
|
||||
)
|
||||
v_map = {
|
||||
(v.name.replace("context_{}".format(2), "context_0")[:-2]): v for v in v_list
|
||||
}
|
||||
saver = tf.train.Saver(v_map)
|
||||
|
||||
if FLAGS.cond_pos:
|
||||
saver.restore(sess, save_path_pos)
|
||||
|
||||
|
||||
save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot))
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3))
|
||||
v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list}
|
||||
save_path_rot = osp.join(
|
||||
FLAGS.logdir, FLAGS.exp_rot, "model_{}".format(FLAGS.resume_rot)
|
||||
)
|
||||
v_list = tf.get_collection(
|
||||
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(3)
|
||||
)
|
||||
v_map = {
|
||||
(v.name.replace("context_{}".format(3), "context_0")[:-2]): v for v in v_list
|
||||
}
|
||||
saver = tf.train.Saver(v_map)
|
||||
|
||||
if FLAGS.cond_rot:
|
||||
|
@ -646,53 +822,57 @@ def main():
|
|||
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
|
||||
x_mod = X_NOISE
|
||||
|
||||
kvs = {}
|
||||
kvs['X_NOISE'] = X_NOISE
|
||||
kvs['LABEL_SIZE'] = LABEL_SIZE
|
||||
kvs['LABEL_SHAPE'] = LABEL_SHAPE
|
||||
kvs['LABEL_POS'] = LABEL_POS
|
||||
kvs['LABEL_ROT'] = LABEL_ROT
|
||||
kvs['model_size'] = model_size
|
||||
kvs['model_shape'] = model_shape
|
||||
kvs['model_pos'] = model_pos
|
||||
kvs['model_rot'] = model_rot
|
||||
kvs['weight_size'] = weight_size
|
||||
kvs['weight_shape'] = weight_shape
|
||||
kvs['weight_pos'] = weight_pos
|
||||
kvs['weight_rot'] = weight_rot
|
||||
kvs["X_NOISE"] = X_NOISE
|
||||
kvs["LABEL_SIZE"] = LABEL_SIZE
|
||||
kvs["LABEL_SHAPE"] = LABEL_SHAPE
|
||||
kvs["LABEL_POS"] = LABEL_POS
|
||||
kvs["LABEL_ROT"] = LABEL_ROT
|
||||
kvs["model_size"] = model_size
|
||||
kvs["model_shape"] = model_shape
|
||||
kvs["model_pos"] = model_pos
|
||||
kvs["model_rot"] = model_rot
|
||||
kvs["weight_size"] = weight_size
|
||||
kvs["weight_shape"] = weight_shape
|
||||
kvs["weight_pos"] = weight_pos
|
||||
kvs["weight_rot"] = weight_rot
|
||||
|
||||
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape))
|
||||
save_exp_dir = osp.join(
|
||||
FLAGS.savedir, "{}_{}_joint".format(FLAGS.exp_size, FLAGS.exp_shape)
|
||||
)
|
||||
if not osp.exists(save_exp_dir):
|
||||
os.makedirs(save_exp_dir)
|
||||
|
||||
|
||||
if FLAGS.task == 'conceptcombine':
|
||||
if FLAGS.task == "conceptcombine":
|
||||
conceptcombine(sess, kvs, data, latents, save_exp_dir)
|
||||
elif FLAGS.task == 'labeldiscover':
|
||||
elif FLAGS.task == "labeldiscover":
|
||||
labeldiscover(sess, kvs, data, latents, save_exp_dir)
|
||||
elif FLAGS.task == 'gentest':
|
||||
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos))
|
||||
elif FLAGS.task == "gentest":
|
||||
save_exp_dir = osp.join(
|
||||
FLAGS.savedir, "{}_{}_gen".format(FLAGS.exp_size, FLAGS.exp_pos)
|
||||
)
|
||||
if not osp.exists(save_exp_dir):
|
||||
os.makedirs(save_exp_dir)
|
||||
|
||||
gentest(sess, kvs, data, latents, save_exp_dir)
|
||||
elif FLAGS.task == 'genbaseline':
|
||||
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos))
|
||||
elif FLAGS.task == "genbaseline":
|
||||
save_exp_dir = osp.join(
|
||||
FLAGS.savedir, "{}_{}_gen_baseline".format(FLAGS.exp_size, FLAGS.exp_pos)
|
||||
)
|
||||
if not osp.exists(save_exp_dir):
|
||||
os.makedirs(save_exp_dir)
|
||||
|
||||
if FLAGS.plot_curve:
|
||||
mse_losses = []
|
||||
for frac in [i/10 for i in range(11)]:
|
||||
mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac)
|
||||
for frac in [i / 10 for i in range(11)]:
|
||||
mse_loss = genbaseline(
|
||||
sess, kvs, data, latents, save_exp_dir, frac=frac
|
||||
)
|
||||
mse_losses.append(mse_loss)
|
||||
np.save("mse_baseline_comb.npy", mse_losses)
|
||||
else:
|
||||
genbaseline(sess, kvs, data, latents, save_exp_dir)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
File diff suppressed because it is too large
Load diff
214
EBMs/fid.py
214
EBMs/fid.py
|
@ -1,5 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
''' Calculates the Frechet Inception Distance (FID) to evalulate GANs.
|
||||
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs.
|
||||
|
||||
The FID metric calculates the distance between two distributions of images.
|
||||
Typically, we have summary statistics (mean & covariance matrix) of one
|
||||
|
@ -14,28 +14,33 @@ the pool_3 layer of the inception net for generated samples and real world
|
|||
samples respectivly.
|
||||
|
||||
See --help to see further details.
|
||||
'''
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import numpy as np
|
||||
|
||||
import os
|
||||
import gzip, pickle
|
||||
import tensorflow as tf
|
||||
from scipy.misc import imread
|
||||
from scipy import linalg
|
||||
import pathlib
|
||||
import urllib
|
||||
import tarfile
|
||||
import urllib
|
||||
import warnings
|
||||
|
||||
MODEL_DIR = '/tmp/imagenet'
|
||||
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from scipy import linalg
|
||||
from scipy.misc import imread
|
||||
|
||||
MODEL_DIR = "/tmp/imagenet"
|
||||
DATA_URL = (
|
||||
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
|
||||
)
|
||||
pool3 = None
|
||||
|
||||
|
||||
class InvalidFIDException(Exception):
|
||||
pass
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
def get_fid_score(images, images_gt):
|
||||
images = np.stack(images, 0)
|
||||
images_gt = np.stack(images_gt, 0)
|
||||
|
@ -52,34 +57,38 @@ def get_fid_score(images, images_gt):
|
|||
def create_inception_graph(pth):
|
||||
"""Creates a graph from saved GraphDef file."""
|
||||
# Creates graph from saved graph_def.pb.
|
||||
with tf.gfile.FastGFile( pth, 'rb') as f:
|
||||
with tf.gfile.FastGFile(pth, "rb") as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString( f.read())
|
||||
_ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
|
||||
#-------------------------------------------------------------------------------
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name="FID_Inception_Net")
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# code for handling inception net derived from
|
||||
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
|
||||
def _get_inception_layer(sess):
|
||||
"""Prepares inception net for batched usage and returns pool_3 layer. """
|
||||
layername = 'FID_Inception_Net/pool_3:0'
|
||||
"""Prepares inception net for batched usage and returns pool_3 layer."""
|
||||
layername = "FID_Inception_Net/pool_3:0"
|
||||
pool3 = sess.graph.get_tensor_by_name(layername)
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
if shape._dims != []:
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
|
||||
return pool3
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_activations(images, sess, batch_size=50, verbose=False):
|
||||
|
@ -100,23 +109,27 @@ def get_activations(images, sess, batch_size=50, verbose=False):
|
|||
# inception_layer = _get_inception_layer(sess)
|
||||
d0 = images.shape[0]
|
||||
if batch_size > d0:
|
||||
print("warning: batch size is bigger than the data size. setting batch size to data size")
|
||||
print(
|
||||
"warning: batch size is bigger than the data size. setting batch size to data size"
|
||||
)
|
||||
batch_size = d0
|
||||
n_batches = d0//batch_size
|
||||
n_used_imgs = n_batches*batch_size
|
||||
pred_arr = np.empty((n_used_imgs,2048))
|
||||
n_batches = d0 // batch_size
|
||||
n_used_imgs = n_batches * batch_size
|
||||
pred_arr = np.empty((n_used_imgs, 2048))
|
||||
for i in range(n_batches):
|
||||
if verbose:
|
||||
print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
|
||||
start = i*batch_size
|
||||
print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
|
||||
start = i * batch_size
|
||||
end = start + batch_size
|
||||
batch = images[start:end]
|
||||
pred = sess.run(pool3, {'ExpandDims:0': batch})
|
||||
pred_arr[start:end] = pred.reshape(batch_size,-1)
|
||||
pred = sess.run(pool3, {"ExpandDims:0": batch})
|
||||
pred_arr[start:end] = pred.reshape(batch_size, -1)
|
||||
if verbose:
|
||||
print(" done")
|
||||
return pred_arr
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
||||
|
@ -147,15 +160,22 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
|||
sigma1 = np.atleast_2d(sigma1)
|
||||
sigma2 = np.atleast_2d(sigma2)
|
||||
|
||||
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
|
||||
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
|
||||
assert (
|
||||
mu1.shape == mu2.shape
|
||||
), "Training and test mean vectors have different lengths"
|
||||
assert (
|
||||
sigma1.shape == sigma2.shape
|
||||
), "Training and test covariances have different dimensions"
|
||||
|
||||
diff = mu1 - mu2
|
||||
|
||||
# product might be almost singular
|
||||
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
||||
if not np.isfinite(covmean).all():
|
||||
msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
|
||||
msg = (
|
||||
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
|
||||
% eps
|
||||
)
|
||||
warnings.warn(msg)
|
||||
offset = np.eye(sigma1.shape[0]) * eps
|
||||
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
||||
|
@ -170,7 +190,9 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
|||
tr_covmean = np.trace(covmean)
|
||||
|
||||
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
|
||||
|
@ -193,47 +215,52 @@ def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
|
|||
mu = np.mean(act, axis=0)
|
||||
sigma = np.cov(act, rowvar=False)
|
||||
return mu, sigma
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# The following functions aren't needed for calculating the FID
|
||||
# they're just here to make this module work as a stand-alone script
|
||||
# for calculating FID scores
|
||||
#-------------------------------------------------------------------------------
|
||||
# -------------------------------------------------------------------------------
|
||||
def check_or_download_inception(inception_path):
|
||||
''' Checks if the path to the inception file is valid, or downloads
|
||||
the file if it is not present. '''
|
||||
INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
||||
"""Checks if the path to the inception file is valid, or downloads
|
||||
the file if it is not present."""
|
||||
INCEPTION_URL = (
|
||||
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
|
||||
)
|
||||
if inception_path is None:
|
||||
inception_path = '/tmp'
|
||||
inception_path = "/tmp"
|
||||
inception_path = pathlib.Path(inception_path)
|
||||
model_file = inception_path / 'classify_image_graph_def.pb'
|
||||
model_file = inception_path / "classify_image_graph_def.pb"
|
||||
if not model_file.exists():
|
||||
print("Downloading Inception model")
|
||||
from urllib import request
|
||||
import tarfile
|
||||
from urllib import request
|
||||
|
||||
fn, _ = request.urlretrieve(INCEPTION_URL)
|
||||
with tarfile.open(fn, mode='r') as f:
|
||||
f.extract('classify_image_graph_def.pb', str(model_file.parent))
|
||||
with tarfile.open(fn, mode="r") as f:
|
||||
f.extract("classify_image_graph_def.pb", str(model_file.parent))
|
||||
return str(model_file)
|
||||
|
||||
|
||||
def _handle_path(path, sess):
|
||||
if path.endswith('.npz'):
|
||||
if path.endswith(".npz"):
|
||||
f = np.load(path)
|
||||
m, s = f['mu'][:], f['sigma'][:]
|
||||
m, s = f["mu"][:], f["sigma"][:]
|
||||
f.close()
|
||||
else:
|
||||
path = pathlib.Path(path)
|
||||
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
|
||||
files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
|
||||
x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
|
||||
m, s = calculate_activation_statistics(x, sess)
|
||||
return m, s
|
||||
|
||||
|
||||
def calculate_fid_given_paths(paths, inception_path):
|
||||
''' Calculates the FID of two paths. '''
|
||||
"""Calculates the FID of two paths."""
|
||||
inception_path = check_or_download_inception(inception_path)
|
||||
|
||||
for p in paths:
|
||||
|
@ -250,43 +277,48 @@ def calculate_fid_given_paths(paths, inception_path):
|
|||
|
||||
|
||||
def _init_inception():
|
||||
global pool3
|
||||
if not os.path.exists(MODEL_DIR):
|
||||
os.makedirs(MODEL_DIR)
|
||||
filename = DATA_URL.split('/')[-1]
|
||||
filepath = os.path.join(MODEL_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
|
||||
filename, float(count * block_size) / float(total_size) * 100.0))
|
||||
sys.stdout.flush()
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
|
||||
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
|
||||
with tf.gfile.FastGFile(os.path.join(
|
||||
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name='')
|
||||
# Works with an arbitrary minibatch size.
|
||||
with tf.Session() as sess:
|
||||
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
if shape._dims != []:
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
|
||||
global pool3
|
||||
if not os.path.exists(MODEL_DIR):
|
||||
os.makedirs(MODEL_DIR)
|
||||
filename = DATA_URL.split("/")[-1]
|
||||
filepath = os.path.join(MODEL_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write(
|
||||
"\r>> Downloading %s %.1f%%"
|
||||
% (filename, float(count * block_size) / float(total_size) * 100.0)
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print("Succesfully downloaded", filename, statinfo.st_size, "bytes.")
|
||||
tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
|
||||
with tf.gfile.FastGFile(
|
||||
os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
|
||||
) as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name="")
|
||||
# Works with an arbitrary minibatch size.
|
||||
with tf.Session() as sess:
|
||||
pool3 = sess.graph.get_tensor_by_name("pool_3:0")
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
if shape._dims != []:
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
|
||||
|
||||
|
||||
if pool3 is None:
|
||||
_init_inception()
|
||||
_init_inception()
|
||||
|
|
42
EBMs/hmc.py
42
EBMs/hmc.py
|
@ -1,11 +1,11 @@
|
|||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.platform import flags
|
||||
flags.DEFINE_bool('proposal_debug', False, 'Print hmc acceptance raes')
|
||||
|
||||
flags.DEFINE_bool("proposal_debug", False, "Print hmc acceptance raes")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def kinetic_energy(velocity):
|
||||
"""Kinetic energy of the current velocity (assuming a standard Gaussian)
|
||||
(x dot x) / 2
|
||||
|
@ -21,6 +21,7 @@ def kinetic_energy(velocity):
|
|||
"""
|
||||
return 0.5 * tf.square(velocity)
|
||||
|
||||
|
||||
def hamiltonian(position, velocity, energy_function):
|
||||
"""Computes the Hamiltonian of the current position, velocity pair
|
||||
|
||||
|
@ -44,13 +45,12 @@ def hamiltonian(position, velocity, energy_function):
|
|||
"""
|
||||
batch_size = tf.shape(velocity)[0]
|
||||
kinetic_energy_flat = tf.reshape(kinetic_energy(velocity), (batch_size, -1))
|
||||
return tf.squeeze(energy_function(position)) + tf.reduce_sum(kinetic_energy_flat, axis=[1])
|
||||
return tf.squeeze(energy_function(position)) + tf.reduce_sum(
|
||||
kinetic_energy_flat, axis=[1]
|
||||
)
|
||||
|
||||
def leapfrog_step(x0,
|
||||
v0,
|
||||
neg_log_posterior,
|
||||
step_size,
|
||||
num_steps):
|
||||
|
||||
def leapfrog_step(x0, v0, neg_log_posterior, step_size, num_steps):
|
||||
|
||||
# Start by updating the velocity a half-step
|
||||
v = v0 - 0.5 * step_size * tf.gradients(neg_log_posterior(x0), x0)[0]
|
||||
|
@ -83,10 +83,8 @@ def leapfrog_step(x0,
|
|||
# return new proposal state
|
||||
return x, v
|
||||
|
||||
def hmc(initial_x,
|
||||
step_size,
|
||||
num_steps,
|
||||
neg_log_posterior):
|
||||
|
||||
def hmc(initial_x, step_size, num_steps, neg_log_posterior):
|
||||
"""Summary
|
||||
|
||||
Parameters
|
||||
|
@ -107,11 +105,13 @@ def hmc(initial_x,
|
|||
"""
|
||||
|
||||
v0 = tf.random_normal(tf.shape(initial_x))
|
||||
x, v = leapfrog_step(initial_x,
|
||||
v0,
|
||||
step_size=step_size,
|
||||
num_steps=num_steps,
|
||||
neg_log_posterior=neg_log_posterior)
|
||||
x, v = leapfrog_step(
|
||||
initial_x,
|
||||
v0,
|
||||
step_size=step_size,
|
||||
num_steps=num_steps,
|
||||
neg_log_posterior=neg_log_posterior,
|
||||
)
|
||||
|
||||
orig = hamiltonian(initial_x, v0, neg_log_posterior)
|
||||
current = hamiltonian(x, v, neg_log_posterior)
|
||||
|
@ -119,10 +119,12 @@ def hmc(initial_x,
|
|||
prob_accept = tf.exp(orig - current)
|
||||
|
||||
if FLAGS.proposal_debug:
|
||||
prob_accept = tf.Print(prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))])
|
||||
prob_accept = tf.Print(
|
||||
prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))]
|
||||
)
|
||||
|
||||
uniform = tf.random_uniform(tf.shape(prob_accept))
|
||||
keep_mask = (prob_accept > uniform)
|
||||
keep_mask = prob_accept > uniform
|
||||
# print(keep_mask.get_shape())
|
||||
|
||||
x_new = tf.where(keep_mask, x, initial_x)
|
||||
|
|
|
@ -1,24 +1,28 @@
|
|||
from models import ResNet128
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
from tensorflow.python.platform import flags
|
||||
import tensorflow as tf
|
||||
|
||||
import imageio
|
||||
from utils import optimistic_restore
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from models import ResNet128
|
||||
from tensorflow.python.platform import flags
|
||||
|
||||
|
||||
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_integer('num_steps', 200, 'num of steps for conditional imagenet sampling')
|
||||
flags.DEFINE_float('step_lr', 180., 'step size for Langevin dynamics')
|
||||
flags.DEFINE_integer('batch_size', 16, 'number of steps to run')
|
||||
flags.DEFINE_string('exp', 'default', 'name of experiments')
|
||||
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
|
||||
flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization in weights in a model')
|
||||
flags.DEFINE_bool('cclass', True, 'conditional models')
|
||||
flags.DEFINE_bool('use_attention', False, 'using attention')
|
||||
flags.DEFINE_string(
|
||||
"logdir", "cachedir", "location where log of experiments will be stored"
|
||||
)
|
||||
flags.DEFINE_integer("num_steps", 200, "num of steps for conditional imagenet sampling")
|
||||
flags.DEFINE_float("step_lr", 180.0, "step size for Langevin dynamics")
|
||||
flags.DEFINE_integer("batch_size", 16, "number of steps to run")
|
||||
flags.DEFINE_string("exp", "default", "name of experiments")
|
||||
flags.DEFINE_integer("resume_iter", -1, "iteration to resume training from")
|
||||
flags.DEFINE_bool(
|
||||
"spec_norm", True, "whether to use spectral normalization in weights in a model"
|
||||
)
|
||||
flags.DEFINE_bool("cclass", True, "conditional models")
|
||||
flags.DEFINE_bool("use_attention", False, "using attention")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def rescale_im(im):
|
||||
return np.clip(im * 256, 0, 255).astype(np.uint8)
|
||||
|
||||
|
@ -32,12 +36,11 @@ if __name__ == "__main__":
|
|||
weights = model.construct_weights("context_0")
|
||||
|
||||
x_mod = X_NOISE
|
||||
x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
|
||||
mean=0.0,
|
||||
stddev=0.005)
|
||||
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
|
||||
|
||||
energy_noise = energy_start = model.forward(x_mod, weights, label=LABEL,
|
||||
reuse=True, stop_at_grad=False, stop_batch=True)
|
||||
energy_noise = energy_start = model.forward(
|
||||
x_mod, weights, label=LABEL, reuse=True, stop_at_grad=False, stop_batch=True
|
||||
)
|
||||
|
||||
x_grad = tf.gradients(energy_noise, [x_mod])[0]
|
||||
energy_noise_old = energy_noise
|
||||
|
@ -54,20 +57,23 @@ if __name__ == "__main__":
|
|||
saver = loader = tf.train.Saver()
|
||||
|
||||
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
|
||||
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
|
||||
model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
|
||||
saver.restore(sess, model_file)
|
||||
|
||||
lx = np.random.permutation(1000)[:16]
|
||||
ims = []
|
||||
|
||||
# What to initialize sampling with.
|
||||
# What to initialize sampling with.
|
||||
x_mod = np.random.uniform(0, 1, size=(FLAGS.batch_size, 128, 128, 3))
|
||||
labels = np.eye(1000)[lx]
|
||||
|
||||
for i in range(FLAGS.num_steps):
|
||||
e, x_mod = sess.run([energy_noise, x_output], {X_NOISE:x_mod, LABEL:labels})
|
||||
ims.append(rescale_im(x_mod).reshape((4, 4, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((512, 512, 3)))
|
||||
|
||||
imageio.mimwrite('sample.gif', ims)
|
||||
|
||||
e, x_mod = sess.run([energy_noise, x_output], {X_NOISE: x_mod, LABEL: labels})
|
||||
ims.append(
|
||||
rescale_im(x_mod)
|
||||
.reshape((4, 4, 128, 128, 3))
|
||||
.transpose((0, 2, 1, 3, 4))
|
||||
.reshape((512, 512, 3))
|
||||
)
|
||||
|
||||
imageio.mimwrite("sample.gif", ims)
|
||||
|
|
|
@ -13,14 +13,11 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Image pre-processing utilities.
|
||||
"""
|
||||
"""Image pre-processing utilities."""
|
||||
import tensorflow as tf
|
||||
|
||||
IMAGE_DEPTH = 3 # color images
|
||||
|
||||
IMAGE_DEPTH = 3 # color images
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
# _R_MEAN = 123.68
|
||||
# _G_MEAN = 116.78
|
||||
|
@ -35,303 +32,318 @@ _RESIZE_MIN = 128
|
|||
|
||||
|
||||
def _decode_crop_and_flip(image_buffer, bbox, num_channels):
|
||||
"""Crops the given image to a random part of the image, and randomly flips.
|
||||
"""Crops the given image to a random part of the image, and randomly flips.
|
||||
|
||||
We use the fused decode_and_crop op, which performs better than the two ops
|
||||
used separately in series, but note that this requires that the image be
|
||||
passed in as an un-decoded string Tensor.
|
||||
We use the fused decode_and_crop op, which performs better than the two ops
|
||||
used separately in series, but note that this requires that the image be
|
||||
passed in as an un-decoded string Tensor.
|
||||
|
||||
Args:
|
||||
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
num_channels: Integer depth of the image buffer for decoding.
|
||||
Args:
|
||||
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
num_channels: Integer depth of the image buffer for decoding.
|
||||
|
||||
Returns:
|
||||
3-D tensor with cropped image.
|
||||
Returns:
|
||||
3-D tensor with cropped image.
|
||||
|
||||
"""
|
||||
# A large fraction of image datasets contain a human-annotated bounding box
|
||||
# delineating the region of the image containing the object of interest. We
|
||||
# choose to create a new bounding box for the object which is a randomly
|
||||
# distorted version of the human-annotated bounding box that obeys an
|
||||
# allowed range of aspect ratios, sizes and overlap with the human-annotated
|
||||
# bounding box. If no box is supplied, then we assume the bounding box is
|
||||
# the entire image.
|
||||
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
|
||||
tf.image.extract_jpeg_shape(image_buffer),
|
||||
bounding_boxes=bbox,
|
||||
min_object_covered=0.1,
|
||||
aspect_ratio_range=[0.75, 1.33],
|
||||
area_range=[0.05, 1.0],
|
||||
max_attempts=100,
|
||||
use_image_if_no_bounding_boxes=True)
|
||||
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
|
||||
"""
|
||||
# A large fraction of image datasets contain a human-annotated bounding box
|
||||
# delineating the region of the image containing the object of interest. We
|
||||
# choose to create a new bounding box for the object which is a randomly
|
||||
# distorted version of the human-annotated bounding box that obeys an
|
||||
# allowed range of aspect ratios, sizes and overlap with the human-annotated
|
||||
# bounding box. If no box is supplied, then we assume the bounding box is
|
||||
# the entire image.
|
||||
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
|
||||
tf.image.extract_jpeg_shape(image_buffer),
|
||||
bounding_boxes=bbox,
|
||||
min_object_covered=0.1,
|
||||
aspect_ratio_range=[0.75, 1.33],
|
||||
area_range=[0.05, 1.0],
|
||||
max_attempts=100,
|
||||
use_image_if_no_bounding_boxes=True,
|
||||
)
|
||||
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
|
||||
|
||||
# Reassemble the bounding box in the format the crop op requires.
|
||||
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
||||
target_height, target_width, _ = tf.unstack(bbox_size)
|
||||
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
|
||||
# Reassemble the bounding box in the format the crop op requires.
|
||||
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
||||
target_height, target_width, _ = tf.unstack(bbox_size)
|
||||
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
|
||||
|
||||
# Use the fused decode and crop op here, which is faster than each in series.
|
||||
cropped = tf.image.decode_and_crop_jpeg(
|
||||
image_buffer, crop_window, channels=num_channels)
|
||||
# Use the fused decode and crop op here, which is faster than each in
|
||||
# series.
|
||||
cropped = tf.image.decode_and_crop_jpeg(
|
||||
image_buffer, crop_window, channels=num_channels
|
||||
)
|
||||
|
||||
# Flip to add a little more random distortion in.
|
||||
cropped = tf.image.random_flip_left_right(cropped)
|
||||
return cropped
|
||||
# Flip to add a little more random distortion in.
|
||||
cropped = tf.image.random_flip_left_right(cropped)
|
||||
return cropped
|
||||
|
||||
|
||||
def _central_crop(image, crop_height, crop_width):
|
||||
"""Performs central crops of the given image list.
|
||||
"""Performs central crops of the given image list.
|
||||
|
||||
Args:
|
||||
image: a 3-D image tensor
|
||||
crop_height: the height of the image following the crop.
|
||||
crop_width: the width of the image following the crop.
|
||||
Args:
|
||||
image: a 3-D image tensor
|
||||
crop_height: the height of the image following the crop.
|
||||
crop_width: the width of the image following the crop.
|
||||
|
||||
Returns:
|
||||
3-D tensor with cropped image.
|
||||
"""
|
||||
shape = tf.shape(input=image)
|
||||
height, width = shape[0], shape[1]
|
||||
Returns:
|
||||
3-D tensor with cropped image.
|
||||
"""
|
||||
shape = tf.shape(input=image)
|
||||
height, width = shape[0], shape[1]
|
||||
|
||||
amount_to_be_cropped_h = (height - crop_height)
|
||||
crop_top = amount_to_be_cropped_h // 2
|
||||
amount_to_be_cropped_w = (width - crop_width)
|
||||
crop_left = amount_to_be_cropped_w // 2
|
||||
return tf.slice(
|
||||
image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
|
||||
amount_to_be_cropped_h = height - crop_height
|
||||
crop_top = amount_to_be_cropped_h // 2
|
||||
amount_to_be_cropped_w = width - crop_width
|
||||
crop_left = amount_to_be_cropped_w // 2
|
||||
return tf.slice(image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
|
||||
|
||||
|
||||
def _mean_image_subtraction(image, means, num_channels):
|
||||
"""Subtracts the given means from each image channel.
|
||||
"""Subtracts the given means from each image channel.
|
||||
|
||||
For example:
|
||||
means = [123.68, 116.779, 103.939]
|
||||
image = _mean_image_subtraction(image, means)
|
||||
For example:
|
||||
means = [123.68, 116.779, 103.939]
|
||||
image = _mean_image_subtraction(image, means)
|
||||
|
||||
Note that the rank of `image` must be known.
|
||||
Note that the rank of `image` must be known.
|
||||
|
||||
Args:
|
||||
image: a tensor of size [height, width, C].
|
||||
means: a C-vector of values to subtract from each channel.
|
||||
num_channels: number of color channels in the image that will be distorted.
|
||||
Args:
|
||||
image: a tensor of size [height, width, C].
|
||||
means: a C-vector of values to subtract from each channel.
|
||||
num_channels: number of color channels in the image that will be distorted.
|
||||
|
||||
Returns:
|
||||
the centered image.
|
||||
Returns:
|
||||
the centered image.
|
||||
|
||||
Raises:
|
||||
ValueError: If the rank of `image` is unknown, if `image` has a rank other
|
||||
than three or if the number of channels in `image` doesn't match the
|
||||
number of values in `means`.
|
||||
"""
|
||||
if image.get_shape().ndims != 3:
|
||||
raise ValueError('Input must be of size [height, width, C>0]')
|
||||
Raises:
|
||||
ValueError: If the rank of `image` is unknown, if `image` has a rank other
|
||||
than three or if the number of channels in `image` doesn't match the
|
||||
number of values in `means`.
|
||||
"""
|
||||
if image.get_shape().ndims != 3:
|
||||
raise ValueError("Input must be of size [height, width, C>0]")
|
||||
|
||||
if len(means) != num_channels:
|
||||
raise ValueError('len(means) must match the number of channels')
|
||||
if len(means) != num_channels:
|
||||
raise ValueError("len(means) must match the number of channels")
|
||||
|
||||
# We have a 1-D tensor of means; convert to 3-D.
|
||||
means = tf.expand_dims(tf.expand_dims(means, 0), 0)
|
||||
# We have a 1-D tensor of means; convert to 3-D.
|
||||
means = tf.expand_dims(tf.expand_dims(means, 0), 0)
|
||||
|
||||
return image - means
|
||||
return image - means
|
||||
|
||||
|
||||
def _smallest_size_at_least(height, width, resize_min):
|
||||
"""Computes new shape with the smallest side equal to `smallest_side`.
|
||||
"""Computes new shape with the smallest side equal to `smallest_side`.
|
||||
|
||||
Computes new shape with the smallest side equal to `smallest_side` while
|
||||
preserving the original aspect ratio.
|
||||
Computes new shape with the smallest side equal to `smallest_side` while
|
||||
preserving the original aspect ratio.
|
||||
|
||||
Args:
|
||||
height: an int32 scalar tensor indicating the current height.
|
||||
width: an int32 scalar tensor indicating the current width.
|
||||
resize_min: A python integer or scalar `Tensor` indicating the size of
|
||||
the smallest side after resize.
|
||||
Args:
|
||||
height: an int32 scalar tensor indicating the current height.
|
||||
width: an int32 scalar tensor indicating the current width.
|
||||
resize_min: A python integer or scalar `Tensor` indicating the size of
|
||||
the smallest side after resize.
|
||||
|
||||
Returns:
|
||||
new_height: an int32 scalar tensor indicating the new height.
|
||||
new_width: an int32 scalar tensor indicating the new width.
|
||||
"""
|
||||
resize_min = tf.cast(resize_min, tf.float32)
|
||||
Returns:
|
||||
new_height: an int32 scalar tensor indicating the new height.
|
||||
new_width: an int32 scalar tensor indicating the new width.
|
||||
"""
|
||||
resize_min = tf.cast(resize_min, tf.float32)
|
||||
|
||||
# Convert to floats to make subsequent calculations go smoothly.
|
||||
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
|
||||
# Convert to floats to make subsequent calculations go smoothly.
|
||||
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
|
||||
|
||||
smaller_dim = tf.minimum(height, width)
|
||||
scale_ratio = resize_min / smaller_dim
|
||||
smaller_dim = tf.minimum(height, width)
|
||||
scale_ratio = resize_min / smaller_dim
|
||||
|
||||
# Convert back to ints to make heights and widths that TF ops will accept.
|
||||
new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
|
||||
new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
|
||||
# Convert back to ints to make heights and widths that TF ops will accept.
|
||||
new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
|
||||
new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
|
||||
|
||||
return new_height, new_width
|
||||
return new_height, new_width
|
||||
|
||||
|
||||
def _aspect_preserving_resize(image, resize_min):
|
||||
"""Resize images preserving the original aspect ratio.
|
||||
"""Resize images preserving the original aspect ratio.
|
||||
|
||||
Args:
|
||||
image: A 3-D image `Tensor`.
|
||||
resize_min: A python integer or scalar `Tensor` indicating the size of
|
||||
the smallest side after resize.
|
||||
Args:
|
||||
image: A 3-D image `Tensor`.
|
||||
resize_min: A python integer or scalar `Tensor` indicating the size of
|
||||
the smallest side after resize.
|
||||
|
||||
Returns:
|
||||
resized_image: A 3-D tensor containing the resized image.
|
||||
"""
|
||||
shape = tf.shape(input=image)
|
||||
height, width = shape[0], shape[1]
|
||||
Returns:
|
||||
resized_image: A 3-D tensor containing the resized image.
|
||||
"""
|
||||
shape = tf.shape(input=image)
|
||||
height, width = shape[0], shape[1]
|
||||
|
||||
new_height, new_width = _smallest_size_at_least(height, width, resize_min)
|
||||
new_height, new_width = _smallest_size_at_least(height, width, resize_min)
|
||||
|
||||
return _resize_image(image, new_height, new_width)
|
||||
return _resize_image(image, new_height, new_width)
|
||||
|
||||
|
||||
def _resize_image(image, height, width):
|
||||
"""Simple wrapper around tf.resize_images.
|
||||
"""Simple wrapper around tf.resize_images.
|
||||
|
||||
This is primarily to make sure we use the same `ResizeMethod` and other
|
||||
details each time.
|
||||
This is primarily to make sure we use the same `ResizeMethod` and other
|
||||
details each time.
|
||||
|
||||
Args:
|
||||
image: A 3-D image `Tensor`.
|
||||
height: The target height for the resized image.
|
||||
width: The target width for the resized image.
|
||||
Args:
|
||||
image: A 3-D image `Tensor`.
|
||||
height: The target height for the resized image.
|
||||
width: The target width for the resized image.
|
||||
|
||||
Returns:
|
||||
resized_image: A 3-D tensor containing the resized image. The first two
|
||||
dimensions have the shape [height, width].
|
||||
"""
|
||||
return tf.image.resize_images(
|
||||
image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
|
||||
align_corners=False)
|
||||
Returns:
|
||||
resized_image: A 3-D tensor containing the resized image. The first two
|
||||
dimensions have the shape [height, width].
|
||||
"""
|
||||
return tf.image.resize_images(
|
||||
image,
|
||||
[height, width],
|
||||
method=tf.image.ResizeMethod.BILINEAR,
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
|
||||
def preprocess_image(image_buffer, bbox, output_height, output_width,
|
||||
num_channels, is_training=False):
|
||||
"""Preprocesses the given image.
|
||||
def preprocess_image(
|
||||
image_buffer, bbox, output_height, output_width, num_channels, is_training=False
|
||||
):
|
||||
"""Preprocesses the given image.
|
||||
|
||||
Preprocessing includes decoding, cropping, and resizing for both training
|
||||
and eval images. Training preprocessing, however, introduces some random
|
||||
distortion of the image to improve accuracy.
|
||||
Preprocessing includes decoding, cropping, and resizing for both training
|
||||
and eval images. Training preprocessing, however, introduces some random
|
||||
distortion of the image to improve accuracy.
|
||||
|
||||
Args:
|
||||
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
output_height: The height of the image after preprocessing.
|
||||
output_width: The width of the image after preprocessing.
|
||||
num_channels: Integer depth of the image buffer for decoding.
|
||||
is_training: `True` if we're preprocessing the image for training and
|
||||
`False` otherwise.
|
||||
Args:
|
||||
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
output_height: The height of the image after preprocessing.
|
||||
output_width: The width of the image after preprocessing.
|
||||
num_channels: Integer depth of the image buffer for decoding.
|
||||
is_training: `True` if we're preprocessing the image for training and
|
||||
`False` otherwise.
|
||||
|
||||
Returns:
|
||||
A preprocessed image.
|
||||
"""
|
||||
if is_training:
|
||||
# For training, we want to randomize some of the distortions.
|
||||
image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
|
||||
image = _resize_image(image, output_height, output_width)
|
||||
else:
|
||||
# For validation, we want to decode, resize, then just crop the middle.
|
||||
image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
|
||||
image = _aspect_preserving_resize(image, _RESIZE_MIN)
|
||||
print(image)
|
||||
image = _central_crop(image, output_height, output_width)
|
||||
Returns:
|
||||
A preprocessed image.
|
||||
"""
|
||||
if is_training:
|
||||
# For training, we want to randomize some of the distortions.
|
||||
image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
|
||||
image = _resize_image(image, output_height, output_width)
|
||||
else:
|
||||
# For validation, we want to decode, resize, then just crop the middle.
|
||||
image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
|
||||
image = _aspect_preserving_resize(image, _RESIZE_MIN)
|
||||
print(image)
|
||||
image = _central_crop(image, output_height, output_width)
|
||||
|
||||
image.set_shape([output_height, output_width, num_channels])
|
||||
image.set_shape([output_height, output_width, num_channels])
|
||||
|
||||
return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
|
||||
return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
|
||||
|
||||
|
||||
def parse_example_proto(example_serialized):
|
||||
"""Parses an Example proto containing a training example of an image.
|
||||
"""Parses an Example proto containing a training example of an image.
|
||||
|
||||
The output of the build_image_data.py image preprocessing script is a dataset
|
||||
containing serialized Example protocol buffers. Each Example proto contains
|
||||
the following fields:
|
||||
The output of the build_image_data.py image preprocessing script is a dataset
|
||||
containing serialized Example protocol buffers. Each Example proto contains
|
||||
the following fields:
|
||||
|
||||
image/height: 462
|
||||
image/width: 581
|
||||
image/colorspace: 'RGB'
|
||||
image/channels: 3
|
||||
image/class/label: 615
|
||||
image/class/synset: 'n03623198'
|
||||
image/class/text: 'knee pad'
|
||||
image/object/bbox/xmin: 0.1
|
||||
image/object/bbox/xmax: 0.9
|
||||
image/object/bbox/ymin: 0.2
|
||||
image/object/bbox/ymax: 0.6
|
||||
image/object/bbox/label: 615
|
||||
image/format: 'JPEG'
|
||||
image/filename: 'ILSVRC2012_val_00041207.JPEG'
|
||||
image/encoded: <JPEG encoded string>
|
||||
image/height: 462
|
||||
image/width: 581
|
||||
image/colorspace: 'RGB'
|
||||
image/channels: 3
|
||||
image/class/label: 615
|
||||
image/class/synset: 'n03623198'
|
||||
image/class/text: 'knee pad'
|
||||
image/object/bbox/xmin: 0.1
|
||||
image/object/bbox/xmax: 0.9
|
||||
image/object/bbox/ymin: 0.2
|
||||
image/object/bbox/ymax: 0.6
|
||||
image/object/bbox/label: 615
|
||||
image/format: 'JPEG'
|
||||
image/filename: 'ILSVRC2012_val_00041207.JPEG'
|
||||
image/encoded: <JPEG encoded string>
|
||||
|
||||
Args:
|
||||
example_serialized: scalar Tensor tf.string containing a serialized
|
||||
Example protocol buffer.
|
||||
Args:
|
||||
example_serialized: scalar Tensor tf.string containing a serialized
|
||||
Example protocol buffer.
|
||||
|
||||
Returns:
|
||||
image_buffer: Tensor tf.string containing the contents of a JPEG file.
|
||||
label: Tensor tf.int32 containing the label.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
text: Tensor tf.string containing the human-readable label.
|
||||
"""
|
||||
# Dense features in Example proto.
|
||||
feature_map = {
|
||||
'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
|
||||
default_value=''),
|
||||
'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
|
||||
default_value=-1),
|
||||
'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
|
||||
default_value=''),
|
||||
}
|
||||
sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
|
||||
# Sparse features in Example proto.
|
||||
feature_map.update(
|
||||
{k: sparse_float32 for k in ['image/object/bbox/xmin',
|
||||
'image/object/bbox/ymin',
|
||||
'image/object/bbox/xmax',
|
||||
'image/object/bbox/ymax']})
|
||||
Returns:
|
||||
image_buffer: Tensor tf.string containing the contents of a JPEG file.
|
||||
label: Tensor tf.int32 containing the label.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
text: Tensor tf.string containing the human-readable label.
|
||||
"""
|
||||
# Dense features in Example proto.
|
||||
feature_map = {
|
||||
"image/encoded": tf.FixedLenFeature([], dtype=tf.string, default_value=""),
|
||||
"image/class/label": tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1),
|
||||
"image/class/text": tf.FixedLenFeature([], dtype=tf.string, default_value=""),
|
||||
}
|
||||
sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
|
||||
# Sparse features in Example proto.
|
||||
feature_map.update(
|
||||
{
|
||||
k: sparse_float32
|
||||
for k in [
|
||||
"image/object/bbox/xmin",
|
||||
"image/object/bbox/ymin",
|
||||
"image/object/bbox/xmax",
|
||||
"image/object/bbox/ymax",
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
features = tf.parse_single_example(example_serialized, feature_map)
|
||||
label = tf.cast(features['image/class/label'], dtype=tf.int32)
|
||||
features = tf.parse_single_example(example_serialized, feature_map)
|
||||
label = tf.cast(features["image/class/label"], dtype=tf.int32)
|
||||
|
||||
xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
|
||||
ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
|
||||
xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
|
||||
ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
|
||||
xmin = tf.expand_dims(features["image/object/bbox/xmin"].values, 0)
|
||||
ymin = tf.expand_dims(features["image/object/bbox/ymin"].values, 0)
|
||||
xmax = tf.expand_dims(features["image/object/bbox/xmax"].values, 0)
|
||||
ymax = tf.expand_dims(features["image/object/bbox/ymax"].values, 0)
|
||||
|
||||
# Note that we impose an ordering of (y, x) just to make life difficult.
|
||||
bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
|
||||
# Note that we impose an ordering of (y, x) just to make life difficult.
|
||||
bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
|
||||
|
||||
# Force the variable number of bounding boxes into the shape
|
||||
# [1, num_boxes, coords].
|
||||
bbox = tf.expand_dims(bbox, 0)
|
||||
bbox = tf.transpose(bbox, [0, 2, 1])
|
||||
# Force the variable number of bounding boxes into the shape
|
||||
# [1, num_boxes, coords].
|
||||
bbox = tf.expand_dims(bbox, 0)
|
||||
bbox = tf.transpose(bbox, [0, 2, 1])
|
||||
|
||||
return features['image/encoded'], label, bbox, features['image/class/text']
|
||||
return features["image/encoded"], label, bbox, features["image/class/text"]
|
||||
|
||||
|
||||
class ImagenetPreprocessor:
|
||||
def __init__(self, image_size, dtype, train):
|
||||
self.image_size = image_size
|
||||
self.dtype = dtype
|
||||
self.train = train
|
||||
def __init__(self, image_size, dtype, train):
|
||||
self.image_size = image_size
|
||||
self.dtype = dtype
|
||||
self.train = train
|
||||
|
||||
def preprocess(self, image_buffer, bbox):
|
||||
# pylint: disable=g-import-not-at-top
|
||||
image = preprocess_image(image_buffer, bbox, self.image_size, self.image_size, IMAGE_DEPTH, is_training=self.train)
|
||||
return tf.cast(image, self.dtype)
|
||||
|
||||
def parse_and_preprocess(self, value):
|
||||
image_buffer, label_index, bbox, _ = parse_example_proto(value)
|
||||
image = self.preprocess(image_buffer, bbox)
|
||||
image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
|
||||
return label_index, image
|
||||
def preprocess(self, image_buffer, bbox):
|
||||
# pylint: disable=g-import-not-at-top
|
||||
image = preprocess_image(
|
||||
image_buffer,
|
||||
bbox,
|
||||
self.image_size,
|
||||
self.image_size,
|
||||
IMAGE_DEPTH,
|
||||
is_training=self.train,
|
||||
)
|
||||
return tf.cast(image, self.dtype)
|
||||
|
||||
def parse_and_preprocess(self, value):
|
||||
image_buffer, label_index, bbox, _ = parse_example_proto(value)
|
||||
image = self.preprocess(image_buffer, bbox)
|
||||
image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
|
||||
return label_index, image
|
||||
|
|
|
@ -1,105 +1,112 @@
|
|||
# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
# Code derived from
|
||||
# tensorflow/tensorflow/models/image/imagenet/classify_image.py
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import math
|
||||
import os.path
|
||||
import sys
|
||||
import tarfile
|
||||
|
||||
import numpy as np
|
||||
from six.moves import urllib
|
||||
import tensorflow as tf
|
||||
import glob
|
||||
import scipy.misc
|
||||
import math
|
||||
import sys
|
||||
|
||||
import horovod.tensorflow as hvd
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from six.moves import urllib
|
||||
|
||||
MODEL_DIR = '/tmp/imagenet'
|
||||
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
||||
MODEL_DIR = "/tmp/imagenet"
|
||||
DATA_URL = (
|
||||
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
|
||||
)
|
||||
softmax = None
|
||||
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.visible_device_list = str(hvd.local_rank())
|
||||
sess = tf.Session(config=config)
|
||||
|
||||
# Call this function with list of images. Each of elements should be a
|
||||
|
||||
# Call this function with list of images. Each of elements should be a
|
||||
# numpy array with values ranging from 0 to 255.
|
||||
def get_inception_score(images, splits=10):
|
||||
# For convenience
|
||||
if len(images[0].shape) != 3:
|
||||
return 0, 0
|
||||
# For convenience
|
||||
if len(images[0].shape) != 3:
|
||||
return 0, 0
|
||||
|
||||
# Bypassing all the assertions so that we don't end prematuraly'
|
||||
# assert(type(images) == list)
|
||||
# assert(type(images[0]) == np.ndarray)
|
||||
# assert(len(images[0].shape) == 3)
|
||||
# assert(np.max(images[0]) > 10)
|
||||
# assert(np.min(images[0]) >= 0.0)
|
||||
inps = []
|
||||
for img in images:
|
||||
img = img.astype(np.float32)
|
||||
inps.append(np.expand_dims(img, 0))
|
||||
bs = 1
|
||||
preds = []
|
||||
n_batches = int(math.ceil(float(len(inps)) / float(bs)))
|
||||
for i in range(n_batches):
|
||||
sys.stdout.write(".")
|
||||
sys.stdout.flush()
|
||||
inp = inps[(i * bs) : min((i + 1) * bs, len(inps))]
|
||||
inp = np.concatenate(inp, 0)
|
||||
pred = sess.run(softmax, {"ExpandDims:0": inp})
|
||||
preds.append(pred)
|
||||
preds = np.concatenate(preds, 0)
|
||||
scores = []
|
||||
for i in range(splits):
|
||||
part = preds[
|
||||
(i * preds.shape[0] // splits) : ((i + 1) * preds.shape[0] // splits), :
|
||||
]
|
||||
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
||||
kl = np.mean(np.sum(kl, 1))
|
||||
scores.append(np.exp(kl))
|
||||
return np.mean(scores), np.std(scores)
|
||||
|
||||
# Bypassing all the assertions so that we don't end prematuraly'
|
||||
# assert(type(images) == list)
|
||||
# assert(type(images[0]) == np.ndarray)
|
||||
# assert(len(images[0].shape) == 3)
|
||||
# assert(np.max(images[0]) > 10)
|
||||
# assert(np.min(images[0]) >= 0.0)
|
||||
inps = []
|
||||
for img in images:
|
||||
img = img.astype(np.float32)
|
||||
inps.append(np.expand_dims(img, 0))
|
||||
bs = 1
|
||||
preds = []
|
||||
n_batches = int(math.ceil(float(len(inps)) / float(bs)))
|
||||
for i in range(n_batches):
|
||||
sys.stdout.write(".")
|
||||
sys.stdout.flush()
|
||||
inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
|
||||
inp = np.concatenate(inp, 0)
|
||||
pred = sess.run(softmax, {'ExpandDims:0': inp})
|
||||
preds.append(pred)
|
||||
preds = np.concatenate(preds, 0)
|
||||
scores = []
|
||||
for i in range(splits):
|
||||
part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
|
||||
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
||||
kl = np.mean(np.sum(kl, 1))
|
||||
scores.append(np.exp(kl))
|
||||
return np.mean(scores), np.std(scores)
|
||||
|
||||
# This function is called automatically.
|
||||
def _init_inception():
|
||||
global softmax
|
||||
if not os.path.exists(MODEL_DIR):
|
||||
os.makedirs(MODEL_DIR)
|
||||
filename = DATA_URL.split('/')[-1]
|
||||
filepath = os.path.join(MODEL_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
|
||||
filename, float(count * block_size) / float(total_size) * 100.0))
|
||||
sys.stdout.flush()
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
|
||||
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
|
||||
with tf.gfile.FastGFile(os.path.join(
|
||||
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name='')
|
||||
# Works with an arbitrary minibatch size.
|
||||
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.set_shape(tf.TensorShape(new_shape))
|
||||
w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
|
||||
logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
|
||||
softmax = tf.nn.softmax(logits)
|
||||
global softmax
|
||||
if not os.path.exists(MODEL_DIR):
|
||||
os.makedirs(MODEL_DIR)
|
||||
filename = DATA_URL.split("/")[-1]
|
||||
filepath = os.path.join(MODEL_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write(
|
||||
"\r>> Downloading %s %.1f%%"
|
||||
% (filename, float(count * block_size) / float(total_size) * 100.0)
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print("Succesfully downloaded", filename, statinfo.st_size, "bytes.")
|
||||
tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
|
||||
with tf.gfile.FastGFile(
|
||||
os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
|
||||
) as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name="")
|
||||
# Works with an arbitrary minibatch size.
|
||||
pool3 = sess.graph.get_tensor_by_name("pool_3:0")
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.set_shape(tf.TensorShape(new_shape))
|
||||
w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
|
||||
logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
|
||||
softmax = tf.nn.softmax(logits)
|
||||
|
||||
|
||||
if softmax is None:
|
||||
_init_inception()
|
||||
_init_inception()
|
||||
|
|
1327
EBMs/models.py
1327
EBMs/models.py
File diff suppressed because it is too large
Load diff
|
@ -1,53 +1,56 @@
|
|||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from tensorflow.python.platform import flags
|
||||
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
|
||||
import os.path as osp
|
||||
import os
|
||||
from utils import optimistic_restore, remap_restore, optimistic_remap_restore
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
from scipy.misc import imsave
|
||||
from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, TFImagenetLoader
|
||||
from torch.utils.data import DataLoader
|
||||
from baselines.common.tf_util import initialize
|
||||
|
||||
import horovod.tensorflow as hvd
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from baselines.common.tf_util import initialize
|
||||
from data import Cifar10, Imagenet, TFImagenetLoader
|
||||
from fid import get_fid_score
|
||||
from inception import get_inception_score
|
||||
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
|
||||
from scipy.misc import imsave
|
||||
from tensorflow.python.platform import flags
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from utils import optimistic_remap_restore
|
||||
|
||||
hvd.init()
|
||||
|
||||
from inception import get_inception_score
|
||||
from fid import get_fid_score
|
||||
|
||||
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_string('exp', 'default', 'name of experiments')
|
||||
flags.DEFINE_bool('cclass', False, 'whether to condition on class')
|
||||
flags.DEFINE_string(
|
||||
"logdir", "cachedir", "location where log of experiments will be stored"
|
||||
)
|
||||
flags.DEFINE_string("exp", "default", "name of experiments")
|
||||
flags.DEFINE_bool("cclass", False, "whether to condition on class")
|
||||
|
||||
# Architecture settings
|
||||
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
|
||||
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
|
||||
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
|
||||
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
|
||||
flags.DEFINE_float('step_lr', 10.0, 'Size of steps for gradient descent')
|
||||
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
|
||||
flags.DEFINE_float('proj_norm', 0.05, 'Maximum change of input images')
|
||||
flags.DEFINE_integer('batch_size', 512, 'batch size')
|
||||
flags.DEFINE_integer('resume_iter', -1, 'resume iteration')
|
||||
flags.DEFINE_integer('ensemble', 10, 'number of ensembles')
|
||||
flags.DEFINE_integer('im_number', 50000, 'number of ensembles')
|
||||
flags.DEFINE_integer('repeat_scale', 100, 'number of repeat iterations')
|
||||
flags.DEFINE_float('noise_scale', 0.005, 'amount of noise to output')
|
||||
flags.DEFINE_integer('idx', 0, 'save index')
|
||||
flags.DEFINE_integer('nomix', 10, 'number of intervals to stop mixing')
|
||||
flags.DEFINE_bool('scaled', True, 'whether to scale noise added')
|
||||
flags.DEFINE_bool('large_model', False, 'whether to use a small or large model')
|
||||
flags.DEFINE_bool('larger_model', False, 'Whether to use a large model')
|
||||
flags.DEFINE_bool('wider_model', False, 'Whether to use a large model')
|
||||
flags.DEFINE_bool('single', False, 'single ')
|
||||
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or imagenet or imagenetfull')
|
||||
flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
|
||||
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
|
||||
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
|
||||
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
|
||||
flags.DEFINE_float("step_lr", 10.0, "Size of steps for gradient descent")
|
||||
flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
|
||||
flags.DEFINE_float("proj_norm", 0.05, "Maximum change of input images")
|
||||
flags.DEFINE_integer("batch_size", 512, "batch size")
|
||||
flags.DEFINE_integer("resume_iter", -1, "resume iteration")
|
||||
flags.DEFINE_integer("ensemble", 10, "number of ensembles")
|
||||
flags.DEFINE_integer("im_number", 50000, "number of ensembles")
|
||||
flags.DEFINE_integer("repeat_scale", 100, "number of repeat iterations")
|
||||
flags.DEFINE_float("noise_scale", 0.005, "amount of noise to output")
|
||||
flags.DEFINE_integer("idx", 0, "save index")
|
||||
flags.DEFINE_integer("nomix", 10, "number of intervals to stop mixing")
|
||||
flags.DEFINE_bool("scaled", True, "whether to scale noise added")
|
||||
flags.DEFINE_bool("large_model", False, "whether to use a small or large model")
|
||||
flags.DEFINE_bool("larger_model", False, "Whether to use a large model")
|
||||
flags.DEFINE_bool("wider_model", False, "Whether to use a large model")
|
||||
flags.DEFINE_bool("single", False, "single ")
|
||||
flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
|
||||
flags.DEFINE_string("dataset", "cifar10", "cifar10 or imagenet or imagenetfull")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class InceptionReplayBuffer(object):
|
||||
def __init__(self, size):
|
||||
"""Create Replay buffer.
|
||||
|
@ -72,14 +75,16 @@ class InceptionReplayBuffer(object):
|
|||
self._label_storage.extend(list(labels))
|
||||
else:
|
||||
if batch_size + self._next_idx < self._maxsize:
|
||||
self._storage[self._next_idx:self._next_idx+batch_size] = list(ims)
|
||||
self._label_storage[self._next_idx:self._next_idx+batch_size] = list(labels)
|
||||
self._storage[self._next_idx : self._next_idx + batch_size] = list(ims)
|
||||
self._label_storage[self._next_idx : self._next_idx + batch_size] = (
|
||||
list(labels)
|
||||
)
|
||||
else:
|
||||
split_idx = self._maxsize - self._next_idx
|
||||
self._storage[self._next_idx:] = list(ims)[:split_idx]
|
||||
self._storage[:batch_size-split_idx] = list(ims)[split_idx:]
|
||||
self._label_storage[self._next_idx:] = list(labels)[:split_idx]
|
||||
self._label_storage[:batch_size-split_idx] = list(labels)[split_idx:]
|
||||
self._storage[self._next_idx :] = list(ims)[:split_idx]
|
||||
self._storage[: batch_size - split_idx] = list(ims)[split_idx:]
|
||||
self._label_storage[self._next_idx :] = list(labels)[:split_idx]
|
||||
self._label_storage[: batch_size - split_idx] = list(labels)[split_idx:]
|
||||
|
||||
self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize
|
||||
|
||||
|
@ -123,12 +128,13 @@ class InceptionReplayBuffer(object):
|
|||
def rescale_im(im):
|
||||
return np.clip(im * 256, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def compute_inception(sess, target_vars):
|
||||
X_START = target_vars['X_START']
|
||||
Y_GT = target_vars['Y_GT']
|
||||
X_finals = target_vars['X_finals']
|
||||
NOISE_SCALE = target_vars['NOISE_SCALE']
|
||||
energy_noise = target_vars['energy_noise']
|
||||
X_START = target_vars["X_START"]
|
||||
Y_GT = target_vars["Y_GT"]
|
||||
X_finals = target_vars["X_finals"]
|
||||
NOISE_SCALE = target_vars["NOISE_SCALE"]
|
||||
energy_noise = target_vars["energy_noise"]
|
||||
|
||||
size = FLAGS.im_number
|
||||
num_steps = size // 1000
|
||||
|
@ -136,16 +142,21 @@ def compute_inception(sess, target_vars):
|
|||
images = []
|
||||
test_ims = []
|
||||
|
||||
|
||||
if FLAGS.dataset == "cifar10":
|
||||
test_dataset = Cifar10(full=True, noise=False)
|
||||
elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
|
||||
test_dataset = Imagenet(train=False)
|
||||
|
||||
if FLAGS.dataset != "imagenetfull":
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False)
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_workers=4,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)
|
||||
test_dataloader = TFImagenetLoader("test", FLAGS.batch_size, 0, 1)
|
||||
|
||||
for data_corrupt, data, label_gt in tqdm(test_dataloader):
|
||||
data = data.numpy()
|
||||
|
@ -155,7 +166,6 @@ def compute_inception(sess, target_vars):
|
|||
test_ims = test_ims[:60000]
|
||||
break
|
||||
|
||||
|
||||
# n = min(len(images), len(test_ims))
|
||||
print(len(test_ims))
|
||||
# fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
|
||||
|
@ -187,12 +197,14 @@ def compute_inception(sess, target_vars):
|
|||
x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
|
||||
label = np.random.randint(0, classes, (FLAGS.batch_size))
|
||||
label = identity[label]
|
||||
x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0]
|
||||
x_new = sess.run(
|
||||
[x_final], {X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale}
|
||||
)[0]
|
||||
data_buffer.add(x_new, label)
|
||||
else:
|
||||
(x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
|
||||
keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99)
|
||||
label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9)
|
||||
keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99
|
||||
label_keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9
|
||||
label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
|
||||
label_corrupt = identity[label_corrupt]
|
||||
x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
|
||||
|
@ -203,7 +215,10 @@ def compute_inception(sess, target_vars):
|
|||
# else:
|
||||
# noise_scale = [0.7]
|
||||
|
||||
x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})
|
||||
x_new, e_noise = sess.run(
|
||||
[x_final, energy_noise],
|
||||
{X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale},
|
||||
)
|
||||
data_buffer.set_elms(idx, x_new, label)
|
||||
|
||||
if FLAGS.im_number != 50000:
|
||||
|
@ -216,14 +231,22 @@ def compute_inception(sess, target_vars):
|
|||
|
||||
images.extend(list(ims))
|
||||
|
||||
saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx))
|
||||
saveim = osp.join("sandbox_cachedir", FLAGS.exp, "test{}.png".format(FLAGS.idx))
|
||||
|
||||
ims = ims[:100]
|
||||
|
||||
if FLAGS.dataset != "imagenetfull":
|
||||
im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3))
|
||||
im_panel = (
|
||||
ims.reshape((10, 10, 32, 32, 3))
|
||||
.transpose((0, 2, 1, 3, 4))
|
||||
.reshape((320, 320, 3))
|
||||
)
|
||||
else:
|
||||
im_panel = ims.reshape((10, 10, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((1280, 1280, 3))
|
||||
im_panel = (
|
||||
ims.reshape((10, 10, 128, 128, 3))
|
||||
.transpose((0, 2, 1, 3, 4))
|
||||
.reshape((1280, 1280, 3))
|
||||
)
|
||||
imsave(saveim, im_panel)
|
||||
|
||||
print("Saved image!!!!")
|
||||
|
@ -237,8 +260,6 @@ def compute_inception(sess, target_vars):
|
|||
print("FID of score {}".format(fid))
|
||||
|
||||
|
||||
|
||||
|
||||
def main(model_list):
|
||||
|
||||
if FLAGS.dataset == "imagenetfull":
|
||||
|
@ -259,45 +280,55 @@ def main(model_list):
|
|||
weights = []
|
||||
|
||||
for i, model_num in enumerate(model_list):
|
||||
weight = model.construct_weights('context_{}'.format(i))
|
||||
weight = model.construct_weights("context_{}".format(i))
|
||||
initialize()
|
||||
save_file = osp.join(logdir, 'model_{}'.format(model_num))
|
||||
save_file = osp.join(logdir, "model_{}".format(model_num))
|
||||
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(i))
|
||||
v_map = {(v.name.replace('context_{}'.format(i), 'context_0')[:-2]): v for v in v_list}
|
||||
v_list = tf.get_collection(
|
||||
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(i)
|
||||
)
|
||||
v_map = {
|
||||
(v.name.replace("context_{}".format(i), "context_0")[:-2]): v
|
||||
for v in v_list
|
||||
}
|
||||
saver = tf.train.Saver(v_map)
|
||||
try:
|
||||
saver.restore(sess, save_file)
|
||||
except:
|
||||
except BaseException:
|
||||
optimistic_remap_restore(sess, save_file, i)
|
||||
weights.append(weight)
|
||||
|
||||
|
||||
if FLAGS.dataset == "imagenetfull":
|
||||
X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype = tf.float32)
|
||||
X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
|
||||
else:
|
||||
X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
|
||||
X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
|
||||
|
||||
if FLAGS.dataset == "cifar10":
|
||||
Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32)
|
||||
Y_GT = tf.placeholder(shape=(None, 10), dtype=tf.float32)
|
||||
else:
|
||||
Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
|
||||
Y_GT = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
|
||||
|
||||
NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32)
|
||||
|
||||
X_finals = []
|
||||
|
||||
|
||||
# Seperate loops
|
||||
for weight in weights:
|
||||
X = X_START
|
||||
|
||||
steps = tf.constant(0)
|
||||
c = lambda i, x: tf.less(i, FLAGS.num_steps)
|
||||
|
||||
def c(i, x):
|
||||
return tf.less(i, FLAGS.num_steps)
|
||||
|
||||
def langevin_step(counter, X):
|
||||
scale_rate = 1
|
||||
|
||||
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE)
|
||||
X = X + tf.random_normal(
|
||||
tf.shape(X),
|
||||
mean=0.0,
|
||||
stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE,
|
||||
)
|
||||
|
||||
energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
|
||||
x_grad = tf.gradients(energy_noise, [X])[0]
|
||||
|
@ -305,7 +336,7 @@ def main(model_list):
|
|||
if FLAGS.proj_norm != 0.0:
|
||||
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
|
||||
|
||||
X = X - FLAGS.step_lr * x_grad * scale_rate
|
||||
X = X - FLAGS.step_lr * x_grad * scale_rate
|
||||
X = tf.clip_by_value(X, 0, 1)
|
||||
|
||||
counter = counter + 1
|
||||
|
@ -318,16 +349,16 @@ def main(model_list):
|
|||
X_finals.append(X_final)
|
||||
|
||||
target_vars = {}
|
||||
target_vars['X_START'] = X_START
|
||||
target_vars['Y_GT'] = Y_GT
|
||||
target_vars['X_finals'] = X_finals
|
||||
target_vars['NOISE_SCALE'] = NOISE_SCALE
|
||||
target_vars['energy_noise'] = energy_noise
|
||||
target_vars["X_START"] = X_START
|
||||
target_vars["Y_GT"] = Y_GT
|
||||
target_vars["X_finals"] = X_finals
|
||||
target_vars["NOISE_SCALE"] = NOISE_SCALE
|
||||
target_vars["energy_noise"] = energy_noise
|
||||
|
||||
compute_inception(sess, target_vars)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# model_list = [117000, 116700]
|
||||
model_list = [FLAGS.resume_iter - 300*i for i in range(FLAGS.ensemble)]
|
||||
model_list = [FLAGS.resume_iter - 300 * i for i in range(FLAGS.ensemble)]
|
||||
main(model_list)
|
||||
|
|
711
EBMs/train.py
711
EBMs/train.py
File diff suppressed because it is too large
Load diff
737
EBMs/utils.py
737
EBMs/utils.py
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue